|
8 | 8 | import functools
|
9 | 9 | import logging
|
10 | 10 | import os
|
11 |
| -from typing import List, Optional, Tuple, Union |
| 11 | +from typing import Dict, List, Optional, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch
|
14 | 14 | import triton # @manual
|
@@ -1281,15 +1281,15 @@ def matmul_fp8_row(
|
1281 | 1281 | output += bias[None, :]
|
1282 | 1282 | return output.to(c.dtype)
|
1283 | 1283 |
|
1284 |
| - def grid(META): |
| 1284 | + def grid(META: Dict[str, int]) -> Tuple[int, int]: |
1285 | 1285 | return (
|
1286 | 1286 | triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
1287 | 1287 | META["SPLIT_K"],
|
1288 | 1288 | )
|
1289 | 1289 |
|
1290 | 1290 | NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
|
1291 | 1291 |
|
1292 |
| - def persistent_grid(META): |
| 1292 | + def persistent_grid(META: Dict[str, int]) -> Tuple[int]: |
1293 | 1293 | return (
|
1294 | 1294 | min(
|
1295 | 1295 | NUM_SMS,
|
@@ -1337,8 +1337,9 @@ def persistent_grid(META):
|
1337 | 1337 | desc_helper.init_tma_descriptor("b_scale")
|
1338 | 1338 | desc_helper.init_tma_descriptor("bias")
|
1339 | 1339 |
|
1340 |
| - def persistent_grid_tma_ws(META): |
| 1340 | + def persistent_grid_tma_ws(META: Dict[str, int]) -> Tuple[int]: |
1341 | 1341 | nonlocal desc_helper # noqa: F824
|
| 1342 | + assert a_scale is not None # Type narrowing for Pyre |
1342 | 1343 | desc_helper.fill_2d_tma_descriptor(
|
1343 | 1344 | "a",
|
1344 | 1345 | a.data_ptr(),
|
@@ -1449,8 +1450,9 @@ def persistent_grid_tma_ws(META):
|
1449 | 1450 | desc_helper.init_tma_descriptor("b_scale")
|
1450 | 1451 | desc_helper.init_tma_descriptor("bias")
|
1451 | 1452 |
|
1452 |
| - def persistent_grid_tma(META): |
| 1453 | + def persistent_grid_tma(META: Dict[str, int]) -> Tuple[int]: |
1453 | 1454 | nonlocal desc_helper # noqa: F824
|
| 1455 | + assert a_scale is not None # Type narrowing for Pyre |
1454 | 1456 | desc_helper.fill_2d_tma_descriptor(
|
1455 | 1457 | "a",
|
1456 | 1458 | a.data_ptr(),
|
@@ -2111,7 +2113,7 @@ def matmul_fp8_block(
|
2111 | 2113 | raise Exception("'b_scale' must be on the same device as 'a'")
|
2112 | 2114 |
|
2113 | 2115 | # noqa: E731:
|
2114 |
| - def grid(META): |
| 2116 | + def grid(META: Dict[str, int]) -> Tuple[int, int]: |
2115 | 2117 | return (
|
2116 | 2118 | triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
|
2117 | 2119 | META["SPLIT_K"],
|
@@ -4254,7 +4256,7 @@ def dequantize_fp8_row(
|
4254 | 4256 | M = xq.shape[0]
|
4255 | 4257 | use_int64 = xq.numel() > 2**31
|
4256 | 4258 |
|
4257 |
| - def grid(meta): |
| 4259 | + def grid(meta: Dict[str, int]) -> Tuple[int]: |
4258 | 4260 | return (triton.cdiv(M, meta["BLOCK_M"]),)
|
4259 | 4261 |
|
4260 | 4262 | with torch.cuda.device(xq.device.index):
|
@@ -4365,7 +4367,7 @@ def dequantize_fp8_packed_row(
|
4365 | 4367 | M = actual_xq.shape[0]
|
4366 | 4368 | use_int64 = actual_xq.numel() > 2**31
|
4367 | 4369 |
|
4368 |
| - def grid(meta): |
| 4370 | + def grid(meta: Dict[str, int]) -> Tuple[int]: |
4369 | 4371 | return (triton.cdiv(M, meta["BLOCK_M"]),)
|
4370 | 4372 |
|
4371 | 4373 | with torch.cuda.device(actual_xq.device.index):
|
@@ -4445,7 +4447,7 @@ def dequantize_fp8_block(
|
4445 | 4447 | M, K = xq.size()
|
4446 | 4448 | x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
|
4447 | 4449 |
|
4448 |
| - def grid(meta): |
| 4450 | + def grid(meta: Dict[str, int]) -> Tuple[int, int]: |
4449 | 4451 | return (
|
4450 | 4452 | triton.cdiv(M, meta["BLOCK_M"]),
|
4451 | 4453 | triton.cdiv(K, meta["BLOCK_K"]),
|
|
0 commit comments