@@ -31,6 +31,8 @@
 import triton
 import triton.language as tl
 
+DEVICE = triton.runtime.driver.active.get_active_torch_device()
+
 
 @triton.autotune(
     configs=[
@@ -141,7 +143,6 @@ def grouped_matmul_kernel(
 
 
 def group_gemm_fn(group_A, group_B):
-    device = torch.device('cuda')
     assert len(group_A) == len(group_B)
     group_size = len(group_A)
 
@@ -157,7 +158,7 @@ def group_gemm_fn(group_A, group_B):
         assert A.shape[1] == B.shape[0]
         M, K = A.shape
         K, N = B.shape
-        C = torch.empty((M, N), device=device, dtype=A.dtype)
+        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
         group_C.append(C)
         A_addrs.append(A.data_ptr())
         B_addrs.append(B.data_ptr())
@@ -166,11 +167,11 @@ def group_gemm_fn(group_A, group_B):
         g_lds += [A.stride(0), B.stride(0), C.stride(0)]
 
     # note these are device tensors
-    d_a_ptrs = torch.tensor(A_addrs, device=device)
-    d_b_ptrs = torch.tensor(B_addrs, device=device)
-    d_c_ptrs = torch.tensor(C_addrs, device=device)
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
     # we use a fixed number of CTA, and it's auto-tunable
     grid = lambda META: (META['NUM_SM'], )
     grouped_matmul_kernel[grid](
@@ -197,8 +198,8 @@ def group_gemm_fn(group_A, group_B):
     M = group_m[i]
     N = group_n[i]
     K = group_k[i]
-    A = torch.rand((M, K), device="cuda", dtype=torch.float16)
-    B = torch.rand((K, N), device="cuda", dtype=torch.float16)
+    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
+    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
     group_A.append(A)
     group_B.append(B)
 
@@ -255,9 +256,9 @@ def benchmark(N, provider):
     g_lds = []
     group_C = []
     for i in range(group_size):
-        A = torch.rand((N, N), device="cuda", dtype=torch.float16)
-        B = torch.rand((N, N), device="cuda", dtype=torch.float16)
-        C = torch.empty((N, N), device="cuda", dtype=torch.float16)
+        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
+        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
         group_A.append(A)
         group_B.append(B)
         group_C.append(C)
@@ -267,11 +268,11 @@ def benchmark(N, provider):
         g_sizes += [N, N, N]
         g_lds += [N, N, N]
 
-    d_a_ptrs = torch.tensor(A_addrs, device="cuda")
-    d_b_ptrs = torch.tensor(B_addrs, device="cuda")
-    d_c_ptrs = torch.tensor(C_addrs, device="cuda")
-    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
-    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
+    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'cublas':
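
For reference, a minimal sketch of the pattern this diff applies: the target device is resolved once from Triton's active driver and reused for every allocation, so the tutorial no longer hard-codes "cuda". The make_group helper and its shape arguments below are illustrative only and are not part of the tutorial itself.

import torch

import triton

# Resolve the target device from Triton's active backend (CUDA, HIP, XPU, ...)
# once at import time, instead of hard-coding "cuda" everywhere.
DEVICE = triton.runtime.driver.active.get_active_torch_device()


def make_group(group_m, group_n, group_k):
    # Illustrative helper: allocate one (A, B, C) triple per problem on DEVICE.
    group_A, group_B, group_C = [], [], []
    for M, N, K in zip(group_m, group_n, group_k):
        group_A.append(torch.rand((M, K), device=DEVICE, dtype=torch.float16))
        group_B.append(torch.rand((K, N), device=DEVICE, dtype=torch.float16))
        group_C.append(torch.empty((M, N), device=DEVICE, dtype=torch.float16))
    return group_A, group_B, group_C

For example, make_group([1024], [1024], [1024]) would allocate a single 1024x1024 problem on whichever device the active Triton backend targets.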