Skip to content

Commit faa3c0f

Browse files
authored
use static cuda ctx for triton kernel launch (#2269)
* use static cuda ctx for triton kernel launch
* fix for xpu
1 parent 61e5e7f commit faa3c0f

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

gptqmodel/nn_modules/triton_utils/dequant.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -264,20 +264,21 @@ def dequant(dtype, qweight, scales, qzeros, g_idx, bits, pack_bits, maxq):
     numels = out.numel()
     grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731

-    dequant_kernel[grid](
-        g_idx,
-        scales,
-        qweight,
-        qzeros,
-        out,
-        torch_dtype_to_triton(out_dtype),
-        numels,
-        pack_bits=pack_bits,
-        maxq=maxq,
-        bits=bits,
-        out_features=out_features,
-        num_groups=num_groups,
-    )
+    with torch.xpu.device(qweight.device) if HAS_XPU else torch.cuda.device(qweight.device):
+        dequant_kernel[grid](
+            g_idx,
+            scales,
+            qweight,
+            qzeros,
+            out,
+            torch_dtype_to_triton(out_dtype),
+            numels,
+            pack_bits=pack_bits,
+            maxq=maxq,
+            bits=bits,
+            out_features=out_features,
+            num_groups=num_groups,
+        )
     return out
```
0 commit comments

Comments
 (0)