@@ -171,9 +171,9 @@ def group_gemm_fn(group_A, group_B):
         g_lds += [A.stride(0), B.stride(0), C.stride(0)]

     # note these are device tensors
-    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
-    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
-    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64)
     d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
     d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
     # we use a fixed number of CTA, and it's auto-tunable
@@ -277,9 +277,9 @@ def benchmark(N, provider):
         g_sizes += [N, N, N]
         g_lds += [N, N, N]

-    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
-    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
-    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
+    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64)
+    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64)
+    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64)
     d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
     d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
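Both hunks make the same fix: the pointer tensors get an explicit dtype. Without one, torch.tensor infers int64 from a list of Python ints, and a raw device address at or above 2**63 would overflow that inferred type, while torch.uint64 covers the full unsigned address range. A minimal sketch of the pattern, assuming PyTorch >= 2.3 (where torch.uint64 is available) and a hypothetical helper name:

import torch

def make_ptr_tensor(tensors, device):
    # Collect raw device addresses as Python ints via Tensor.data_ptr().
    addrs = [t.data_ptr() for t in tensors]
    # Pin the dtype explicitly: int64 inference would overflow for any
    # address >= 2**63, so store the pointers as unsigned 64-bit ints.
    return torch.tensor(addrs, device=device, dtype=torch.uint64)

# Hypothetical usage mirroring the diff: d_a_ptrs = make_ptr_tensor(group_A, DEVICE)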