
Commit 98431a9

[TEST] Fix triton_kernels IGC after bea27e3
Signed-off-by: Whitney Tsang <[email protected]>
Parent: fe0365c

5 files changed: +219, -1682 lines

python/triton_kernels/tests/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -37,4 +37,5 @@ def pytest_configure(config):
     if worker_id is not None and worker_id.startswith("gw"):
         import torch
         gpu_id = int(worker_id[2:])  # map gw0 → 0, gw1 → 1, ...
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id % torch.cuda.device_count())
+        if torch.cuda.is_available():
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id % torch.cuda.device_count())
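
The guard matters because torch.cuda.device_count() returns 0 on hosts without CUDA (likely the case on the Intel GPU CI this commit targets), so the unguarded modulo raised ZeroDivisionError when pytest-xdist workers started. A minimal standalone sketch of the fixed logic, assuming only that torch is installed; pin_worker_to_gpu is a hypothetical name, not from the repo:

    import os

    import torch

    def pin_worker_to_gpu(worker_id: str) -> None:
        # Map pytest-xdist worker ids to GPUs: gw0 → 0, gw1 → 1, ...
        gpu_id = int(worker_id[2:])
        # Without this guard, torch.cuda.device_count() is 0 on CUDA-less
        # hosts and the modulo below raises ZeroDivisionError.
        if torch.cuda.is_available():
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id % torch.cuda.device_count())

    pin_worker_to_gpu("gw3")  # safe no-op on CUDA-less hosts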

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -700,7 +700,7 @@ def matmul_ogs_torch(x, w, bias,
         if k > 0:
             out[expt] = matmul_ogs_torch(
                 x[:, start_x:start_x+k], w[start_w:start_w+k, :], None,
-                None, None, None, None, betas, gammas, None, round_x, round_y
+                None, None, None, None, betas, gammas, None, round_x, round_y, device
             )
         padded_k = triton.cdiv(k, block_k) * block_k
         start_x += padded_k if inner_routing_data.x_is_padded else k
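
This one-line change forwards the device argument (added to matmul_ogs_torch in bea27e3) into the recursive per-expert call; without it, the inner call silently fell back to the parameter's default, presumably breaking the torch reference path on non-default (e.g. XPU) devices. A sketch of that bug class under hypothetical names (ref_block_matmul is not from the repo):

    import torch

    def ref_block_matmul(x, w, device="cpu"):
        # Recursive reference matmul over column blocks; every recursive
        # call must forward `device`, or the inner call silently uses the
        # default instead of the caller's device.
        if x.shape[1] > 2:
            k = x.shape[1] // 2
            return (ref_block_matmul(x[:, :k], w[:k, :], device) +
                    ref_block_matmul(x[:, k:], w[k:, :], device))
        return x.to(device) @ w.to(device)

    out = ref_block_matmul(torch.randn(4, 4), torch.randn(4, 4), device="cpu")
    print(out.shape)  # torch.Size([4, 4])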
