@@ -548,36 +548,40 @@ def serialize_args(args, constants, signature):
548548
549549 cnt = 0
550550 args_dict = {"gridX" : args [cnt ], "gridY" : args [cnt + 1 ], "gridZ" : args [cnt + 2 ]}
551+ # 3: stream
552+ # 4: function
553+ # 5: packed kernel metadata
554+ assert type (args [cnt + 5 ]).__name__ == "KernelMetadata"
555+ serialize_kernel_metadata (args [cnt + 5 ], args_dict )
556+ # 6: launch_metadata
557+ # 7: launch_enter_hook
558+ # 8: launch_exit_hook
551559 args_dict ['argument_list' ] = []
552560 counts = {"tensors" : 0 , "scalars" : 0 , "karg_cnt" : 0 }
553- cnt = 4
561+ cnt += 9
554562 for arg in args [cnt :]:
555- if type (arg ).__name__ == "KernelMetadata" :
556- serialize_kernel_metadata (arg , args_dict )
557-
563+ sig_name = list (signature .keys ())[counts ['karg_cnt' ]]
558564 if isinstance (arg , torch .Tensor ):
559565 cpu_tensor = arg .cpu ()
560566 tensor_path = os .path .join (dir_path , f"tensor_{ counts ['tensors' ]} .pt" )
561567 with open (tensor_path , 'wb' ) as f :
562568 torch .save (cpu_tensor , f )
563569 new_arg = {
564570 "name" : f"tensor_{ counts ['tensors' ]} " , "type" : "tensor" , "dtype" : str (arg .dtype ), "ctype" :
565- signature [counts [ 'karg_cnt' ] ]
571+ signature [sig_name ]
566572 }
567573 args_dict ['argument_list' ].append (new_arg )
568- counts ['karg_cnt' ] += 1
569574 counts ['tensors' ] += 1
570-
571575 if isinstance (arg , numbers .Number ):
572- if counts ['karg_cnt' ] not in constants :
576+ if ( counts ['karg_cnt' ], ) not in constants . keys () :
573577 new_arg = {
574- "name" : f"scalarArg_{ counts ['scalars' ]} " , "type" : "scalar" , "value" : args [ cnt ] , "ctype" :
575- signature [counts [ 'karg_cnt' ] ]
578+ "name" : f"scalarArg_{ counts ['scalars' ]} " , "type" : "scalar" , "value" : arg , "ctype" :
579+ signature [sig_name ]
576580 }
577581 args_dict ['argument_list' ].append (new_arg )
578- counts ['karg_cnt' ] += 1
579582 counts ['scalars' ] += 1
580- cnt += 1
583+ counts ['karg_cnt' ] += 1
584+
581585 # Dump argument info as a JSON file
582586 json_path = os .path .join (dir_path , 'args_data.json' )
583587 with open (json_path , 'w' ) as json_file :
0 commit comments