1010import inspect
1111import math
1212
13- kLiteralTensorSize = 64
14-
1513
1614def apply_templates (forward_code : str ) -> str :
1715 tab = " "
@@ -54,7 +52,7 @@ def process_tensor(tensor):
5452
5553 info = tensor_info (tensor )
5654 if tensor .dtype in [torch .int8 , torch .int16 , torch .int32 , torch .int64 ]:
57- if tensor .numel () < kLiteralTensorSize :
55+ if tensor .numel () < 64 :
5856 return {
5957 "type" : "small_int_tensor" ,
6058 "data" : tensor .clone (),
@@ -67,7 +65,7 @@ def process_tensor(tensor):
6765 "max_val" : tensor .max ().item (),
6866 "info" : info ,
6967 }
70- elif tensor .numel () < kLiteralTensorSize :
68+ elif tensor .numel () < 64 :
7169 return {"type" : "small_tensor" , "data" : tensor .clone (), "info" : info }
7270 else :
7371 return {"type" : "random_tensor" , "info" : info }
@@ -82,7 +80,7 @@ def process_tensor(tensor):
8280 def handle_named_tensors (tensor ):
8381 info = tensor_info (tensor )
8482 if tensor .dtype in [torch .int8 , torch .int16 , torch .int32 , torch .int64 ]:
85- if tensor .numel () < kLiteralTensorSize :
83+ if tensor .numel () < 64 :
8684 return {
8785 "info" : info ,
8886 "data" : tensor .clone (),
@@ -95,7 +93,7 @@ def handle_named_tensors(tensor):
9593 "max_val" : tensor .max ().item (),
9694 "type" : "big_int_tensor_by_range" ,
9795 }
98- if tensor .numel () < kLiteralTensorSize :
96+ if tensor .numel () < 64 :
9997 return {"info" : info , "data" : tensor .clone (), "type" : "small_tensor" }
10098 else :
10199 return {"info" : info , "data" : None , "type" : "random_tensor" }
0 commit comments