@@ -1,5 +1,7 @@
 import argparse
 import logging
+import csv
+import os
 
 from typing import Any, Callable, List, Optional
 
@@ -33,16 +35,29 @@
     HAS_TMA = False
     logger.warning(f"Failed to import TMA: {e}")
 
-
 def parse_args(args):
     parser = argparse.ArgumentParser(description="TritonBench fp8_gemm")
     parser.add_argument("--llama", action="store_true")
     parser.add_argument("--scaling_rowwise", action="store_true")
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--filepath", type=str, default=None)
     return parser.parse_args(args)
 
+def read_fp8_shapes(filepath):
+    fp8_shapes = []
+    try:
+        with open(filepath, 'r', newline='') as csvfile:
+            filtered_lines = (line for line in csvfile if line.strip() and not line.lstrip().startswith('#'))
+            reader = csv.reader(filtered_lines)
+            for row in reader:
+                fp8_shapes.append(tuple(map(int, row)))
+    except Exception as e:
+        logger.error(f"Failed to read fp8 shapes from {filepath}: {e}")
+        raise
+    return fp8_shapes
+
 
 class Operator(BenchmarkOperator):
     DEFAULT_METRICS = ["tflops", "gbps", "latency"]
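For reference, a sketch of the shapes file read_fp8_shapes expects: blank lines and lines starting with '#' are filtered out before csv.reader sees them, and every remaining row is parsed as a comma-separated m,n,k triple of ints. The file name and shape values below are hypothetical, not taken from the commit:

    # fp8_shapes.csv (hypothetical)
    # m,n,k
    8192,1280,8192
    8192,7168,8192

With that file, read_fp8_shapes("fp8_shapes.csv") returns [(8192, 1280, 8192), (8192, 7168, 8192)]. Note the function re-raises after logging, so a malformed row (e.g. a non-integer field) aborts the run rather than being silently skipped.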
@@ -70,6 +85,10 @@ def args(m, n, k):
                 yield args(m, n, k)
         elif self.extra_args.m:
             yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k)
+        elif self.extra_args.filepath:
+            fp8_shapes = read_fp8_shapes(self.extra_args.filepath)
+            for m, n, k in fp8_shapes:
+                yield args(m, n, k)
         else:
             for i in range(10, 15):
                 for j in range(0, 4):
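The new branch is checked after --llama and --m, so those flags take precedence over --filepath; with none of the three set, the benchmark still falls back to the shapes swept by the nested range loops. Assuming the usual TritonBench entry point (the runner invocation is an assumption, not part of this diff), the flag would be exercised along these lines:

    python run.py --op fp8_gemm --filepath fp8_shapes.csv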
@@ -114,8 +133,8 @@ def pt2_fp8_gemm(self, a, b) -> Callable:
             scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             out_dtype = torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
+            scale_a = torch.tensor(1.0, dtype=torch.float32, device=a.device)
+            scale_b = torch.tensor(1.0, dtype=torch.float32, device=b.device)
             out_dtype = torch.float16
         f = lambda a, b: torch._scaled_mm(
             a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
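The dtype pin is the substance of this hunk: torch._scaled_mm requires float32 scale tensors, and a bare torch.tensor(1.0) is only float32 as long as the process default dtype has not been changed (e.g. by torch.set_default_dtype). A minimal standalone sketch of the tensor-wise call, assuming a CUDA device with fp8 support such as H100; shapes and names are illustrative:

    import torch

    M, K, N = 128, 64, 256
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    # The second operand must be column-major: materialize (N, K), then transpose.
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
    # Explicit float32 scales, as in the fix above.
    scale_a = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    scale_b = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    out = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16)
    print(out.shape)  # (M, N) -> torch.Size([128, 256])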