@@ -301,7 +301,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
301301 steps :list [dict ] = []
302302 if (pmc := v .get (ProfilePMCEvent )):
303303 steps .append (create_step ("PMC" , ("/prg-pmc" , len (ctxs ), len (steps )), pmc ))
304- all_counters [(name , run_number [k ], k )] = pmc [0 ]
304+ all_counters [(name , run_number [k ], pname )] = pmc [0 ]
305305 # to decode a SQTT trace, we need the raw stream, program binary and device properties
306306 if (sqtt := v .get (ProfileSQTTEvent )):
307307 for e in sqtt :
@@ -345,7 +345,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
345345 # * init decoder
346346 from extra .sqtt .roc import decode
347347 base = unwrap (p .base )
348- addr_table = amd_decode (device_props [p .device ]["gfx_target_version" ], unwrap ( p . lib ) )
348+ addr_table = amd_decode (unwrap ( p . lib ), device_props [p .device ]["gfx_target_version" ], )
349349 disasm :dict [int , tuple [str , int ]] = {addr + base :(inst .disasm (), inst .size ()) for addr , inst in addr_table .items ()}
350350 rctx = decode (data , {p .name :disasm })
351351 cu_events :dict [str , list [ProfileEvent ]] = {}
@@ -432,7 +432,7 @@ def amd_readelf(lib:bytes) -> list[dict]:
432432 ".group_segment_fixed_size" :"LDS size" , ".private_segment_fixed_size" :"Scratch size" }
433433 return [{"label" :label , "value" :v } for k ,label in keys .items () if (v := notes ["amdhsa.kernels" ][0 ][k ]) > 0 ]
434434
435- def amd_decode (target : int , lib : bytes ) -> dict [int , Any ]: # Any is the Inst class from extra.assembly.amd.dsl
435+ def amd_decode (lib : bytes , target : int ) -> dict [int , Any ]: # Any is the Inst class from extra.assembly.amd.dsl
436436 from tinygrad .runtime .support .elf import elf_loader
437437 from extra .assembly .amd import detect_format
438438 from extra .assembly .amd .dsl import Inst
@@ -460,7 +460,7 @@ def parse_branch(inst) -> int|None:
460460COND_TAKEN , COND_NOT_TAKEN , UNCOND = range (3 )
461461def amdgpu_cfg (lib :bytes , target :int ) -> dict :
462462 # decode
463- pc_table = amd_decode (target , lib )
463+ pc_table = amd_decode (lib , target )
464464 # get leaders
465465 leaders :set [int ] = {next (iter (pc_table ))}
466466 for pc , inst in pc_table .items ():
@@ -509,15 +509,14 @@ def get_render(query:str) -> dict:
509509 if fmt == "asm" :
510510 ret :dict = {"metadata" :[]}
511511 if data .device .startswith ("AMD" ) and data .lib is not None :
512- with soft_err (lambda err : ret .update (err )):
513- ret .update (amdgpu_cfg (lib := data .lib , device_props [data .device ]["gfx_target_version" ]))
514- with soft_err (lambda err : ret ["metadata" ].append (err )): ret ["metadata" ].append (amd_readelf (lib ))
512+ with soft_err (lambda err : ret .update (err )): ret .update (amdgpu_cfg (data .lib , device_props [data .device ]["gfx_target_version" ]))
513+ with soft_err (lambda err : ret ["metadata" ].append (err )): ret ["metadata" ].append (amd_readelf (data .lib ))
515514 else : ret ["src" ] = get_stdout (lambda : (compiler := Device [data .device ].compiler ).disassemble (compiler .compile (data .src )))
516515 return ret
517516 if fmt == "all-pmc" :
518517 durations , pmc = data
519518 ret = {"cols" :{}, "rows" :[]}
520- for (name , n , k ),events in data [ 1 ] .items ():
519+ for (name , n , k ),events in pmc .items ():
521520 pmc_table = unpack_pmc (events )
522521 ret ["cols" ].update ([(r [0 ], None ) for r in pmc_table ["rows" ]])
523522 ret ["rows" ].append ((name , durations [k ][n - 1 ], * [r [1 ] for r in pmc_table ["rows" ]]))
0 commit comments