Skip to content

Commit 1998e0b

Browse files
authored
nv: add prof props to dev (tinygrad#14437)
1 parent 7a9dee4 commit 1998e0b

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

extra/nv_pma/decode.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,25 @@ def print_aggregated(samples:list[tuple[PMASample, int]]) -> None:
169169
print("Usage: python decode.py <pkl_file> [--raw] [--sm=0xNNN]")
170170
sys.exit(1)
171171

172-
# Parse --sm=0xNNN argument
173-
sm_version = 0x800 # default to Ampere
174-
for arg in sys.argv:
175-
if arg.startswith("--sm="):
176-
sm_version = int(arg[5:], 0)
177-
178172
with open(sys.argv[1], "rb") as f:
179173
data = pickle.load(f)
180174

181-
if isinstance(data, dict): dumps = list(enumerate(data["pma_raw_dumps"]))
182-
else: dumps = [(i, e.blob) for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent")]
183-
184-
record_size = 9 if sm_version >= 0x890 else 8
185-
print(f"SM version: 0x{sm_version:x}, using {record_size}-byte records")
186-
187-
for dump_idx, raw in dumps:
175+
if isinstance(data, dict):
176+
sm_version = 0x800 # default to Ampere
177+
for arg in sys.argv:
178+
if arg.startswith("--sm="): sm_version = int(arg[5:], 0)
179+
dumps = [(i, x, sm_version) for i, x in enumerate(data["pma_raw_dumps"])]
180+
else:
181+
devs = {e.device: e for e in data if type(e).__name__ == "ProfileDeviceEvent"}
182+
dumps = []
183+
for i, e in enumerate(e for e in data if type(e).__name__ == "ProfilePMAEvent"):
184+
dumps.append((i, e.blob, devs[e.device].props.get('sm_version', 0x800)))
185+
186+
for dump_idx, raw, sm_ver in dumps:
188187
print(f"\n{'='*60}\nDump {dump_idx} ({len(raw)} bytes, {len(raw)//32} packets)\n{'='*60}")
189-
if "--raw" in sys.argv: print_packets(raw, sm_version)
188+
if "--raw" in sys.argv: print_packets(raw, sm_ver)
190189
else:
191-
samples = list(decode(raw, sm_version))
190+
samples = list(decode(raw, sm_ver))
192191
print(f"\nDecoded {len(samples)} samples:")
193192
print_samples(samples)
194193
print_aggregated(samples)

tinygrad/runtime/ops_nv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,5 @@ def _prof_readback(self) -> bytes|None:
828828
self.iface.rm_control(self.profiler, nv_gpu.NVB0CC_CTRL_CMD_PMA_STREAM_UPDATE_GET_PUT,
829829
nv_gpu.struct_NVB0CC_CTRL_PMA_STREAM_UPDATE_GET_PUT_PARAMS(bytesConsumed=params.bytesAvailable))
830830
return pma_data
831+
832+
def device_props(self): return {'arch': self.arch, 'sm_version': self.sm_version}

0 commit comments

Comments
 (0)