
Commit 6c98657

ezyang authored and pytorchmergebot committed
Add some Triton related suppressions that don't show on CI (pytorch#166868)
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#166868
Approved by: https://github.com/maggiemoss, https://github.com/zou3519
1 parent 86b2d82 commit 6c98657
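
For context, a pyrefly suppression is an ordinary comment that silences one diagnostic on the line immediately below it; this diff uses both the bracketed and the trailing-comment spellings. A minimal sketch of the two forms (the helper function below is hypothetical; only the comment syntax mirrors this commit):

    from typing import Any

    def triton_kernel_fn_name(kernel: Any) -> str:
        # Return the underlying Triton function name for a possibly wrapped kernel.
        if hasattr(kernel, "_fn_name"):
            # pyrefly: ignore [missing-attribute]
            return kernel._fn_name
        # pyrefly: ignore # missing-attribute
        return kernel.fn._fn_name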

7 files changed (+57, -2 lines)


torch/_dynamo/repro/after_aot.py

Lines changed: 1 addition & 0 deletions
@@ -405,6 +405,7 @@ def generate_compiler_repro_string(
         # pyrefly: ignore [missing-attribute]
         kernel._fn_name
         if isinstance(kernel, JITFunction)
+        # pyrefly: ignore # missing-attribute
         else kernel.fn._fn_name
     )
     fn_name = fn_name.split(".")[-1]

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 16 additions & 0 deletions
@@ -264,6 +264,7 @@ def generate_ttir(
 
     assert isinstance(kernel, JITFunction)
 
+    # pyrefly: ignore # missing-attribute
     context = triton._C.libtriton.ir.context()
     target = triton.runtime.driver.active.get_current_target()
     backend = triton.compiler.compiler.make_backend(target)
@@ -305,6 +306,7 @@ def generate_ttir(
             base_tensor = torch.empty(
                 [elements_per_dim] * len(block_shape), dtype=a.dtype
             )
+            # pyrefly: ignore # bad-argument-type
             ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape)
         elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)):
             with torch._C._DisableTorchDispatch():
@@ -368,6 +370,7 @@ def _get_specialization(args): # type: ignore[no-untyped-def]
 
         target = triton.runtime.driver.active.get_current_target()
         backend_ = triton.compiler.compiler.make_backend(target)
+        # pyrefly: ignore # missing-attribute
         return backend_.get_attrs_descriptor(args, kernel.params)
     else:
         assert (
@@ -384,6 +387,7 @@ def _get_specialization(args): # type: ignore[no-untyped-def]
     except TypeError: # Unknown arg `specialize_extra`
         # Older versions of Triton take specialize_extra as an arg to specialize_impl
         specialize_impl = functools.partial(
+            # pyrefly: ignore # missing-argument
             triton.runtime.jit.create_specialize_impl(),
             specialize_extra=backend.get_arg_specialization,
         )
@@ -468,6 +472,7 @@ def get_signature_value(idx: int, arg: Any) -> str:
         if i not in constexprs
     }
 
+    # pyrefly: ignore # missing-attribute
     triton._C.libtriton.ir.load_dialects(context)
     backend.load_dialects(context)
 
@@ -477,22 +482,29 @@ def get_signature_value(idx: int, arg: Any) -> str:
     # backward compatibility here.
     make_ir_sig_params = len(inspect.signature(src.make_ir).parameters)
     get_codegen_implementation_sig_params = len(
+        # pyrefly: ignore # missing-attribute
         inspect.signature(backend.get_codegen_implementation).parameters
     )
     if make_ir_sig_params == 2:
+        # pyrefly: ignore # missing-argument
         ttir_module = src.make_ir(options, context)
     elif make_ir_sig_params == 3:
+        # pyrefly: ignore # missing-attribute
         codegen_fns = backend.get_codegen_implementation()
+        # pyrefly: ignore # missing-argument
         ttir_module = src.make_ir(options, codegen_fns, context)
     elif make_ir_sig_params == 4:
         codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
+        # pyrefly: ignore # missing-attribute
         codegen_fns = backend.get_codegen_implementation(*codegen_args)
         module_map = backend.get_module_map()
         ttir_module = src.make_ir(options, codegen_fns, module_map, context)
     else:
         codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
+        # pyrefly: ignore # missing-attribute
         codegen_fns = backend.get_codegen_implementation(*codegen_args)
         module_map = backend.get_module_map()
+        # pyrefly: ignore # bad-argument-count
         ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
     if not ttir_module.verify():
         raise RuntimeError("Verification for TTIR module has failed")
@@ -1102,6 +1114,7 @@ def triton_kernel_wrapper_mutation_dense(
             from triton.tools.tensor_descriptor import TensorDescriptor
 
             block_shape = stable_meta[0]
+            # pyrefly: ignore # bad-argument-type
             kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape)
 
     # move as many positional arguments from dicts to args as we
@@ -1658,6 +1671,7 @@ def call_triton_kernel(
                     "Passing multiple @triton.autotune decorators is not supported. "
                     "Please use a single @triton.autotune decorator instead."
                 )
+            # pyrefly: ignore # missing-attribute
             iter_kernel = iter_kernel.fn
 
         # Process the @triton.heuristics decorator:
@@ -1868,6 +1882,7 @@ def call_triton_kernel(
 
         # Both for grid's meta as well as for the kernel, we need combined
         # args and kwargs combined and normalized
+        # pyrefly: ignore # missing-attribute
         combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs}
 
         # precompute the grid for the kernel
@@ -2061,6 +2076,7 @@ def __init__(
         kernel_idx: Optional[int],
         grid: Optional["TritonGridType"],
     ) -> None:
+        # pyrefly: ignore # bad-assignment
         self.kernel = None
         self.grid = None
         tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
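
Several of the suppressions above (missing-argument, bad-argument-count) exist because generate_ttir dispatches on the arity of src.make_ir, which has changed across Triton releases, so no single call form matches a fixed type stub. A standalone sketch of that arity-dispatch pattern (all names below are illustrative rather than the actual PyTorch helpers):

    import inspect
    from typing import Any

    def call_make_ir(src: Any, target: Any, options: Any,
                     codegen_fns: Any, module_map: Any, context: Any) -> Any:
        # Count the parameters this Triton version's make_ir accepts and pick
        # the matching call form; a checker pinned to one signature flags the
        # other branches, hence the inline suppressions in the diff above.
        n_params = len(inspect.signature(src.make_ir).parameters)
        if n_params == 2:
            return src.make_ir(options, context)
        if n_params == 3:
            return src.make_ir(options, codegen_fns, context)
        if n_params == 4:
            return src.make_ir(options, codegen_fns, module_map, context)
        return src.make_ir(target, options, codegen_fns, module_map, context)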

torch/_inductor/codegen/wrapper.py

Lines changed: 5 additions & 1 deletion
@@ -289,11 +289,15 @@ def traverse(cur_kernel):
             if isinstance(symbol, JITFunction):
                 compile_wrapper.newline()
                 compile_wrapper.writeline("@triton.jit")
+                # pyrefly: ignore # missing-attribute
                 compile_wrapper.splice(symbol.src, strip=True)
                 symbols_included.add(symbol_name)
                 traverse(symbol)
             elif hasattr(triton, "constexpr_function") and isinstance(
-                symbol, triton.runtime.jit.ConstexprFunction
+                # pyrefly: ignore # missing-attribute
+                symbol,
+                # pyrefly: ignore # missing-attribute
+                triton.runtime.jit.ConstexprFunction,
             ):
                 compile_wrapper.newline()
                 compile_wrapper.writeline("@triton.constexpr_function")
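
The wrapper.py hunk splits the isinstance arguments onto separate lines so each can carry its own suppression; the underlying issue is that triton.runtime.jit.ConstexprFunction only exists in newer Triton builds, so the code guards on hasattr at runtime while a checker working from older stubs still reports a missing attribute. A generic sketch of that guard-plus-suppression pattern (the module attribute below is deliberately made up):

    import math

    def uses_optional_helper(x: float) -> float:
        # Runtime guard: the attribute may be absent in older versions of the
        # dependency, so probe for it before touching it.
        if hasattr(math, "hypothetical_new_helper"):
            # The guard is invisible to a checker whose stubs lack the
            # attribute, so the access gets an inline suppression.
            # pyrefly: ignore # missing-attribute
            return math.hypothetical_new_helper(x)
        return x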

torch/_inductor/codegen/wrapper_fxir.py

Lines changed: 2 additions & 0 deletions
@@ -949,7 +949,9 @@ def tune_kernel(tuner: CachingAutotuner, call_args: Sequence[Any]) -> None:
     from triton.runtime import driver
 
     log.info("Autotuning Triton kernel %s at compile time.", kernel_name)
+    # pyrefly: ignore # missing-attribute
     device = driver.active.get_current_device()
+    # pyrefly: ignore # missing-attribute
     stream = driver.active.get_current_stream(device)
 
     def node_to_tuning_arg(arg: Any) -> Any:

torch/_inductor/ir.py

Lines changed: 5 additions & 1 deletion
@@ -6985,6 +6985,7 @@ def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]:
 
             configs = kernel.configs
             kernel = kernel.fn
+        # pyrefly: ignore # bad-return
         return kernel, configs, restore_value_args, reset_to_zero_args
 
     @override
@@ -7140,7 +7141,10 @@ def __init__(
         self.mutable_args = [
             kernel_args[key]
             for key in identify_mutated_tensors(
-                kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata
+                # pyrefly: ignore # bad-argument-type
+                kernel,
+                {**kernel_args, **autotuned_kwargs},
+                tma_descriptor_metadata,
             )
         ]

torch/_inductor/utils.py

Lines changed: 4 additions & 0 deletions
@@ -2555,11 +2555,14 @@ def get_device_tflops(dtype: torch.dtype) -> float:
         return get_max_simd_tflops(torch.float32, sm_clock)
     else:
         if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
+            # pyrefly: ignore # missing-argument
             return get_max_tensorcore_tflops(dtype)
 
         if torch.backends.cuda.matmul.allow_tf32:
+            # pyrefly: ignore # missing-argument
             return get_max_tensorcore_tflops(torch.float32)
         else:
+            # pyrefly: ignore # missing-argument
             return get_max_simd_tflops(torch.float32)
 
 
@@ -2573,6 +2576,7 @@ def get_gpu_dram_gbps() -> int:
 def get_gpu_shared_memory() -> int:
     from triton.runtime import driver
 
+    # pyrefly: ignore # missing-attribute
     return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)

torch/sparse/_triton_ops.py

Lines changed: 24 additions & 0 deletions
@@ -1302,17 +1302,28 @@ def kernel(grid, *sliced_tensors):
         # pyrefly: ignore [unsupported-operation]
         _bsr_strided_addmm_kernel[grid](
             *ptr_stride_extractor(*sliced_tensors),
+            # pyrefly: ignore # bad-argument-count
             beta,
             alpha,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             beta_is_one=beta == 1,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             beta_is_nonzero=beta != 0,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             alpha_is_one=alpha == 1,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             left_alpha_is_one=left_alpha_is_one,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             right_alpha_is_one=right_alpha_is_one,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             BLOCKSIZE_ROW=BM,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             BLOCKSIZE_INNER=BK,
+            # pyrefly: ignore # bad-keyword-argument
             BLOCKSIZE_COL=BN,
+            # pyrefly: ignore # bad-keyword-argument
             allow_tf32=dot_out_dtype == tl.float32,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             acc_dtype=dot_out_dtype,
             **meta,
         )
@@ -1633,12 +1644,17 @@ def kernel(grid, *sliced_tensors):
             beta,
             is_beta_zero,
             *blocksize,
+            # pyrefly: ignore # bad-argument-count
             k,
             tile_k,
             *ptr_stride_extractor(*sliced_tensors),
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             acc_dtype=acc_dtype,
+            # pyrefly: ignore # bad-keyword-argument, bad-argument-type
             allow_tf32=allow_tf32,
+            # pyrefly: ignore # unexpected-keyword
             num_stages=1,
+            # pyrefly: ignore # unexpected-keyword
             num_warps=4,
         )
 
@@ -1923,6 +1939,7 @@ def bsr_softmax(input, max_row_nnz=None):
     def kernel(grid, *sliced_tensors):
         _bsr_softmax_kernel[grid](
             *ptr_stride_extractor(*sliced_tensors),
+            # pyrefly: ignore # bad-argument-count
             row_block,
             col_block,
             max_row_nnz,
@@ -2096,8 +2113,11 @@ def grid(META):
     if "allow_tf32" not in meta:
         meta.update(allow_tf32=dot_out_dtype == tl.float32)
     _scatter_mm2_kernel[grid](
+        # pyrefly: ignore # bad-argument-type
         M,
+        # pyrefly: ignore # bad-argument-type
         K,
+        # pyrefly: ignore # bad-argument-type
         N,
         blocks,
         blocks.stride(0),
@@ -2116,7 +2136,9 @@ def grid(META):
         pq_indices,
         pq_indices.stride(0),
         pq_indices.stride(1),
+        # pyrefly: ignore # bad-argument-type
         dot_out_dtype=dot_out_dtype,
+        # pyrefly: ignore # bad-argument-type
         **meta,
     )
 
@@ -2299,6 +2321,7 @@ def grid(META):
     _scatter_mm6_kernel[grid](
         B,
         Ms,
+        # pyrefly: ignore # bad-argument-type
         Ks,
         N,
         blocks,
@@ -2317,6 +2340,7 @@ def grid(META):
         r_offsets,
         p_offsets,
         q_offsets,
+        # pyrefly: ignore # bad-argument-type
         dot_out_dtype=dot_out_dtype,
         **meta,
     )
