Skip to content

Commit d7cc43c

Browse files
authored
Fix flex_attention test
Differential Revision: D81257355. Pull Request resolved: #373
1 parent 94afdc9 commit d7cc43c

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

test/test_gpu/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def _run_one_operator(args: List[str]):
8989
del op
9090
tb_args.mode = "bwd"
9191
if tb_args.op in BWD_ARGS_OPS:
92-
extra_args.extend(BWD_ARGS_OPS[tb_args.op])
92+
args.extend(BWD_ARGS_OPS[tb_args.op])
93+
tb_args, extra_args = parser.parse_known_args(args)
9394
op = Operator(tb_args=tb_args, extra_args=extra_args)
9495
op.run()
9596
check_ci_output(op)

tritonbench/operators/flex_attention/operator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,7 @@ def sdpa_fn():
415415
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
416416
return sdpa(q, k, v, is_causal=is_causal)
417417
except RuntimeError as e:
418-
print(f"[SKIP] cuDNN backend failed: {e}")
419-
return None
418+
raise NotImplementedError(str(e))
420419

421420
return sdpa_fn
422421

tritonbench/utils/triton_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,8 +1507,8 @@ def _init_extra_metrics() -> Dict[str, Any]:
15071507
self.dump_ir(input_id, fn)
15081508
except torch.cuda.OutOfMemoryError:
15091509
metrics.error_msg = "CUDA OOM"
1510-
except NotImplementedError:
1511-
metrics.error_msg = "not supported"
1510+
except NotImplementedError as e:
1511+
metrics.error_msg = str(e)
15121512
except Exception as e:
15131513
if not self.tb_args.keep_going:
15141514
raise

0 commit comments

Comments (0)