@@ -56,7 +56,7 @@ def test_compile():
5656 """
5757
5858 kernel_name = "vector_add"
59- kernel_sources = KernelSource (kernel_name , kernel_string , "cuda " )
59+ kernel_sources = KernelSource (kernel_name , kernel_string , "HIP " )
6060 kernel_instance = KernelInstance (kernel_name , kernel_sources , kernel_string , [], None , None , dict (), [])
6161 dev = kt_hip .HipFunctions (0 )
6262 try :
@@ -71,11 +71,11 @@ def test_memset_and_memcpy_dtoh():
7171 x = np .array (a ).astype (np .int8 )
7272 x_d = hip .hipMalloc (x .nbytes )
7373
74- Hipfunc = kt_hip .HipFunctions ()
75- Hipfunc .memset (x_d , 4 , x .nbytes )
74+ dev = kt_hip .HipFunctions ()
75+ dev .memset (x_d , 4 , x .nbytes )
7676
7777 output = np .empty (4 , dtype = np .int8 )
78- Hipfunc .memcpy_dtoh (output , x_d )
78+ dev .memcpy_dtoh (output , x_d )
7979
8080 assert all (output == np .full (4 , 4 ))
8181
@@ -86,12 +86,45 @@ def test_memcpy_htod():
8686 x_d = hip .hipMalloc (x .nbytes )
8787 output = np .empty (4 , dtype = np .float32 )
8888
89- Hipfunc = kt_hip .HipFunctions ()
90- Hipfunc .memcpy_htod (x_d , x )
91- Hipfunc .memcpy_dtoh (output , x_d )
89+ dev = kt_hip .HipFunctions ()
90+ dev .memcpy_htod (x_d , x )
91+ dev .memcpy_dtoh (output , x_d )
9292
9393 assert all (output == x )
9494
@skip_if_no_pyhip
def test_copy_constant_memory_args():
    """Check that data copied into a ``__constant__`` symbol is seen by a kernel.

    The kernel copies the 100-element constant array ``my_constant_data`` into
    its output buffer; the test uploads a known pattern through
    ``copy_constant_memory_args``, launches the kernel, and compares the
    downloaded output against that pattern.
    """
    kernel_string = """
    __constant__ float my_constant_data[100];
    __global__ void copy_data_kernel(float* output) {
        int idx = threadIdx.x + blockIdx.x * blockDim.x;
        if (idx < 100) {
            output[idx] = my_constant_data[idx];
        }
    }
    """

    kernel_name = "copy_data_kernel"
    kernel_sources = KernelSource(kernel_name, kernel_string, "HIP")
    kernel_instance = KernelInstance(
        kernel_name, kernel_sources, kernel_string, [], None, None, dict(), []
    )
    dev = kt_hip.HipFunctions(0)
    kernel = dev.compile(kernel_instance)

    # Upload the reference pattern; the dict key must match the symbol name
    # declared in the kernel source.
    my_constant_data = np.full(100, 23, dtype=np.float32)
    cmem_args = {'my_constant_data': my_constant_data}
    dev.copy_constant_memory_args(cmem_args)

    # Zero-initialised host buffer, so a kernel that wrote nothing cannot
    # accidentally match the expected pattern.
    output = np.zeros(100, dtype=np.float32)
    gpu_args = dev.ready_argument_list([output])

    # One block of 100 threads: one thread per element of the constant array.
    threads = (100, 1, 1)
    grid = (1, 1, 1)
    dev.run_kernel(kernel, gpu_args, threads, grid)

    # field0 is the first (and only) entry of the prepared argument list,
    # i.e. the device-side counterpart of ``output``.
    dev.memcpy_dtoh(output, gpu_args.field0)

    assert (my_constant_data == output).all()
127+
def dummy_func(a, b, block=0, grid=0, stream=None, shared=0, texrefs=None):
    """Accept a kernel-launch-like argument list and do nothing.

    All arguments are ignored and ``None`` is returned.
    NOTE(review): presumably a stand-in for a compiled kernel callable in
    tests that follow; its callers are outside this view -- confirm.
    """
    return None
97130
0 commit comments