7777 ScalarType ,
7878 String ,
7979 Tensor ,
80+ TensorDataLocation ,
8081 TensorList ,
8182 TensorShapeDynamism ,
8283)
@@ -372,8 +373,8 @@ def _save_new_const_tensor(
372373 spec : TensorSpec ,
373374 buffer_data : bytes ,
374375 hashed : str ,
375- allocation_info : Optional [AllocationDetails ],
376- constant_tag : str ,
376+ allocation_info : Optional [AllocationDetails ] = None ,
377+ constant_tag : Optional [ str ] = None ,
377378 ) -> int :
378379 """Saves a new constant tensor to the constant buffer and returns the buffer idx"""
379380
@@ -395,7 +396,7 @@ def _save_new_const_tensor(
395396 if (
396397 spec .extra_tensor_info is not None
397398 and spec .extra_tensor_info .fully_qualified_name is not None
398- and spec .extra_tensor_info .location == DataLocation .EXTERNAL
399+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
399400 ):
400401 assert (
401402 constant_tag is not None
@@ -460,7 +461,7 @@ def _tensor_spec_to_evalue(
460461 )
461462 elif (
462463 spec .extra_tensor_info is not None
463- and spec .extra_tensor_info .location == DataLocation .EXTERNAL
464+ and spec .extra_tensor_info .location == TensorDataLocation .EXTERNAL
464465 ):
465466 buffer_idx = self .program_state .external_constant_hash .get (hashed , - 1 )
466467 else :
@@ -1614,7 +1615,7 @@ def placeholder(
16141615 ), "constant tagged tensors require a fully qualified name"
16151616 if spec .extra_tensor_info is None :
16161617 spec .extra_tensor_info = ExtraTensorInfo (
1617- fully_qualified_name = fqn , location = DataLocation .EXTERNAL
1618+ fully_qualified_name = fqn , location = TensorDataLocation .EXTERNAL
16181619 )
16191620 else :
16201621 spec .extra_tensor_info .fully_qualified_name = fqn
0 commit comments