@@ -596,19 +596,6 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         args=args,
     )
 
-    # # Override dtype of the model as specified by the user args.
-    # if dtype_override:
-    #     assert isinstance(
-    #         dtype_override, DType
-    #     ), "Override dtype needs to be of type <DType>"
-    #     torch_dtype = dtype_override.to_torch_dtype()
-    #     logging.info(f"model.to {torch_dtype}")
-    #     edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
-    #     metadata_str=args.metadata,
-    #     dtype_override=dtype_override,
-    #     args=args,
-    # )
-
     # Assumes the checkpoint has uniform dtype.
     checkpoint_dtype = next(edge_manager.model.parameters()).dtype
     print(f"checkpoint dtype: {checkpoint_dtype}")
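Note on the first hunk: the deleted commented-out block applied a user-requested dtype override by converting the loaded model in place, while the surviving lines only record the checkpoint dtype. A minimal, illustrative sketch of that override pattern (not part of this change; in the removed comments the target dtype came from a `DType` enum via `to_torch_dtype()`):

```python
import logging

import torch


def apply_dtype_override(model: torch.nn.Module, torch_dtype: torch.dtype) -> torch.nn.Module:
    """Sketch only: convert the model in place to the user-requested dtype.

    In the removed block, torch_dtype came from dtype_override.to_torch_dtype(),
    where dtype_override was the DType value parsed from the user args.
    """
    logging.info(f"model.to {torch_dtype}")
    return model.to(dtype=torch_dtype)
```

Per the comment later in this function, the intent is that the actual ops run in the precision of `dtype_override` rather than in the checkpoint's dtype.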
@@ -619,14 +606,6 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         )
     )
 
-    quantized = torch.load("/home/jackzhxng/torchrepos/executorch/fake_quantized_weights.pt")
-    # Test run the model to trace.
-    edge_manager.model(
-        torch.tensor([[2, 3, 4]], dtype=torch.long),
-        {"input_pos": torch.tensor([0], dtype=torch.long)},
-    )
-    # torch.testing.assert_close()
-
     # We want to compute the actual ops in the precision of the dtype_override.
     def _set_precision_to_fp32(module):
         """
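Note on the second hunk: it removes leftover debug code that loaded fake-quantized weights from a hardcoded local path and ran the model once on dummy inputs. If a pre-export smoke test like that is still useful, a standalone sketch along these lines could live outside the export path; `smoke_test_forward` is a hypothetical helper that reuses the dummy token tensor and `input_pos` from the deleted lines:

```python
import torch


def smoke_test_forward(model: torch.nn.Module) -> torch.Tensor:
    """Sketch only: one forward pass with the dummy inputs the deleted debug code used."""
    tokens = torch.tensor([[2, 3, 4]], dtype=torch.long)  # (1, 3) prompt token ids
    input_pos = {"input_pos": torch.tensor([0], dtype=torch.long)}  # start position
    with torch.no_grad():
        # The removed lines passed the dict positionally, so the sketch does the same.
        return model(tokens, input_pos)
```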