
Commit 36c506b

generatedunixname893464919433493 authored and facebook-github-bot committed
Fix Pyre missing annotations in fp8_gemm grid functions (#4836)
Summary:
X-link: facebookresearch/FBGEMM#1862
Pull Request resolved: #4836

## Instructions about RACER Diffs

**Commandeer this diff (recommended) or land it with the accept2ship tag.**

*This feature is still in BETA and we are continuously improving it. Your constructive feedback helps improve RACER and is highly appreciated.*

This diff was pre-created by the RACER AI agent for your convenience on top of T234955854. The how-to-code instruction was provided by oncall [Sergey Parshin](https://www.internalfb.com/profile/view/100018888769456). For questions or suggestions, please post in the [RACER Maintainer](https://fb.workplace.com/groups/742040101615185) group.

You will receive **FULL CREDIT** (ETS) for this diff if you either:

- [**Recommended**] Commandeer and land this diff after another reviewer's approval.
- Accept it and ship it after the required approvals are provided.

Once landed, feel free to claim the associated task and the EYS project.

This diff fixes pyre-missing-annotations warnings identified by Quality Insight from the [Monetization codehub](https://fburl.com/quality/0kw8oby2).

- If you are happy with the changes, commandeer the diff if minor edits are needed (**we encourage commandeering to get the diff credit**).
- If you are not happy with the changes, please comment on the diff with clear actions and send it back to the author. RACER will pick it up and regenerate.
- If you really feel RACER is not helping with this change (alas, some complex changes are hard for AI), feel free to abandon this diff.
- **For M10N reviewers:** as you review AI-generated diffs, we ask you to give them the same priority as human-generated diffs and take action in a timely manner by either accepting, rejecting, or resigning as a reviewer. For diffs that don't meet the quality bar (e.g. code that doesn't compile, isn't readable, or introduces functionality regressions), please use the following hashtags to provide clear signals to improve our tools: `#monlowqualitydiff` `#monwrongreviewerdiff`

## Summary

This diff adds missing type annotations to grid functions in fp8_gemm.py to resolve Pyre missing-annotations warnings. The changes include:

1. Added `Dict` to the typing imports.
2. Added type annotations to 8 grid functions throughout the file:
   - Parameter type: `META: Dict[str, int]` (the configuration dictionary).
   - Return types: `-> Tuple[int]` or `-> Tuple[int, int]` (the grid dimensions).
3. Added type-narrowing assertions for Optional parameters to fix Pyre type-checking errors.

The grid functions are used by Triton kernels for GPU computation (see the illustrative sketch below); the annotations improve type safety and code documentation without changing functionality. The added assertions ensure that Optional parameters are properly handled in code paths where they are required to be non-None.

---

> Generated by [RACER](https://www.internalfb.com/wiki/RACER_(Risk-Aware_Code_Editing_and_Refactoring)/), powered by [Confucius](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/): [Session](https://www.internalfb.com/confucius?session_id=9de3f2aa-8fc3-11f0-a605-cb22353bcb8e&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=9de3f2aa-8fc3-11f0-a605-cb22353bcb8e&tab=Trace)

Reviewed By: q10

Differential Revision: D81925341

fbshipit-source-id: fc2d431f667a16aeec24e0bf50391121c10c431d
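For context on what these grid functions do, here is a minimal, self-contained sketch of how Triton consumes a grid callable with exactly the annotations this diff adds. This is not FBGEMM code: the `_double_kernel` kernel, the sizes, and the `BLOCK` meta-parameter are made up for illustration.

```python
from typing import Dict, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def _double_kernel(x_ptr, n, BLOCK: tl.constexpr):
    # Each program handles one BLOCK-sized slice of the input.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2.0, mask=mask)


x = torch.randn(1024, device="cuda")


# Annotated like the grid functions in this diff: META maps
# meta-parameter names to their chosen values, and the return value
# is the launch grid (1D here, hence Tuple[int]).
def grid(META: Dict[str, int]) -> Tuple[int]:
    return (triton.cdiv(x.numel(), META["BLOCK"]),)


# Triton indexes the kernel with the callable and invokes it with the
# resolved meta-parameters to size the launch.
_double_kernel[grid](x, x.numel(), BLOCK=256)
```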
1 parent 3a5e741 commit 36c506b

File tree

1 file changed: +11 −9 lines


fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -8,7 +8,7 @@
 import functools
 import logging
 import os
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import triton  # @manual
```
```diff
@@ -1281,15 +1281,15 @@ def matmul_fp8_row(
         output += bias[None, :]
         return output.to(c.dtype)

-    def grid(META):
+    def grid(META: Dict[str, int]) -> Tuple[int, int]:
         return (
             triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
             META["SPLIT_K"],
         )

     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

-    def persistent_grid(META):
+    def persistent_grid(META: Dict[str, int]) -> Tuple[int]:
         return (
             min(
                 NUM_SMS,
```
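The persistent variant caps the launch at the device's SM count so each program stays resident and loops over multiple tiles. The hunk above is truncated mid-expression by the diff context window; the following is a hedged sketch of the complete idea (`total_tiles` is my name, and the actual FBGEMM body may differ):

```python
from typing import Dict, Tuple

import torch
import triton

M, N = 8192, 8192  # illustrative problem sizes
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count


def persistent_grid(META: Dict[str, int]) -> Tuple[int]:
    # Launch at most one program per SM; each program then iterates
    # over its share of the output tiles instead of exiting after one.
    total_tiles = triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])
    return (min(NUM_SMS, total_tiles),)
```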
```diff
@@ -1337,8 +1337,9 @@ def persistent_grid(META):
     desc_helper.init_tma_descriptor("b_scale")
     desc_helper.init_tma_descriptor("bias")

-    def persistent_grid_tma_ws(META):
+    def persistent_grid_tma_ws(META: Dict[str, int]) -> Tuple[int]:
         nonlocal desc_helper  # noqa: F824
+        assert a_scale is not None  # Type narrowing for Pyre
         desc_helper.fill_2d_tma_descriptor(
             "a",
             a.data_ptr(),
```
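The inserted assert is the standard narrowing idiom for Pyre (and mypy): after `assert x is not None`, the checker treats `x` as non-Optional for the rest of the scope. A minimal illustration with a hypothetical `fill_descriptor` helper, not the FBGEMM code:

```python
from typing import Optional

import torch


def fill_descriptor(a_scale: Optional[torch.Tensor]) -> int:
    # Without this assert, Pyre reports that `data_ptr` is not an
    # attribute of Optional[Tensor]; the assert narrows the type to
    # Tensor for everything that follows.
    assert a_scale is not None  # Type narrowing for Pyre
    return a_scale.data_ptr()
```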
```diff
@@ -1449,8 +1450,9 @@ def persistent_grid_tma_ws(META):
     desc_helper.init_tma_descriptor("b_scale")
     desc_helper.init_tma_descriptor("bias")

-    def persistent_grid_tma(META):
+    def persistent_grid_tma(META: Dict[str, int]) -> Tuple[int]:
         nonlocal desc_helper  # noqa: F824
+        assert a_scale is not None  # Type narrowing for Pyre
         desc_helper.fill_2d_tma_descriptor(
             "a",
             a.data_ptr(),
```
```diff
@@ -2111,7 +2113,7 @@ def matmul_fp8_block(
         raise Exception("'b_scale' must be on the same device as 'a'")

     # noqa: E731:
-    def grid(META):
+    def grid(META: Dict[str, int]) -> Tuple[int, int]:
         return (
             triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
             META["SPLIT_K"],
```
```diff
@@ -4254,7 +4256,7 @@ def dequantize_fp8_row(
     M = xq.shape[0]
     use_int64 = xq.numel() > 2**31

-    def grid(meta):
+    def grid(meta: Dict[str, int]) -> Tuple[int]:
         return (triton.cdiv(M, meta["BLOCK_M"]),)

     with torch.cuda.device(xq.device.index):
```
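Unrelated to the annotations but visible in the context lines: the `use_int64` guard exists because Triton index arithmetic defaults to 32-bit integers (my understanding of the kernel's indexing scheme), which overflow once a tensor holds more than 2**31 - 1 elements. A tiny arithmetic check with made-up sizes:

```python
# Hypothetical element counts; no allocation needed to see the threshold.
small = 4096 * 4096      # ~1.7e7 elements
large = 50_000 * 50_000  # 2.5e9 elements

print(small > 2**31)  # False -> 32-bit offsets are safe
print(large > 2**31)  # True  -> kernel must compute offsets in int64
```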
```diff
@@ -4365,7 +4367,7 @@ def dequantize_fp8_packed_row(
     M = actual_xq.shape[0]
     use_int64 = actual_xq.numel() > 2**31

-    def grid(meta):
+    def grid(meta: Dict[str, int]) -> Tuple[int]:
         return (triton.cdiv(M, meta["BLOCK_M"]),)

     with torch.cuda.device(actual_xq.device.index):
```
```diff
@@ -4445,7 +4447,7 @@ def dequantize_fp8_block(
     M, K = xq.size()
     x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)

-    def grid(meta):
+    def grid(meta: Dict[str, int]) -> Tuple[int, int]:
         return (
             triton.cdiv(M, meta["BLOCK_M"]),
             triton.cdiv(K, meta["BLOCK_K"]),
```
