diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index 0e9181ac55a..01567528b80 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -aec9b2ab77389967ef39bb9c10662fd0fe3e185a +21a304b17ffa9288b0357633d00804a646bb8a15 diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index f9d408c1bae..a50e2732f15 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -100,11 +100,11 @@ def test_mv2_u55_BI(self): ) if common.is_option_enabled("corstone300"): tester.run_method_and_compare_outputs( - atol=1.0, qtol=1, inputs=self.model_inputs + atol=1.0, qtol=1, inputs=self.model_inputs, target_board="corstone-300" ) def test_mv2_u85_BI(self): - ( + tester = ( ArmTester( self.mv2, example_inputs=self.model_inputs, @@ -116,4 +116,9 @@ def test_mv2_u85_BI(self): .check(list(self.operators_after_quantization)) .partition() .to_executorch() + .serialize() ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + atol=1.0, qtol=1, inputs=self.model_inputs, target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index cff8af11654..e3eeb187da3 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -137,16 +137,22 @@ def test_add_u55_BI(self, test_data: torch.Tensor): test_data, ) if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) @parameterized.expand(Add.test_parameters) def test_add_u85_BI(self, test_data: torch.Tensor): test_data = (test_data,) - self._test_add_ethos_BI_pipeline( + tester = self._test_add_ethos_BI_pipeline( self.Add(), common.get_u85_compile_spec(permute_memory_to_nhwc=True), test_data, ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) @parameterized.expand(Add2.test_parameters) def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): @@ -165,11 +171,17 @@ def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): self.Add2(), common.get_u55_compile_spec(), test_data ) if common.is_option_enabled("corstone300"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) @parameterized.expand(Add2.test_parameters) def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) - self._test_add_ethos_BI_pipeline( + tester = self._test_add_ethos_BI_pipeline( self.Add2(), common.get_u85_compile_spec(), test_data ) + if common.is_option_enabled("corstone300"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 2935a2e13ef..0a0143e14c6 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -177,6 +177,7 @@ def __init__( self.qp_input: list[QuantizationParams] = None self.qp_output: QuantizationParams = None self.timeout = 120 + self.target_board: str = None self._has_init_run = False @@ -185,11 +186,17 @@ def init_run( exported_program: 
ExportedProgram, edge_program: ExportedProgram, is_quantized: bool, + target_board: str, ): + + if target_board not in ["corstone-300", "corstone-320"]: + raise RuntimeError(f"Unknown target board: {target_board}") + self.input_names = _get_input_names(edge_program) self.output_node = _get_output_node(exported_program) self.output_name = self.output_node.name self.is_quantized = is_quantized + self.target_board = target_board if is_quantized: self.qp_input = _get_input_quantization_params(exported_program) @@ -205,7 +212,7 @@ def init_run( def set_timeout(self, timeout: int): self.timeout = timeout - def run_corstone300( + def run_corstone( self, inputs: Tuple[torch.Tensor], ) -> list[torch.Tensor]: @@ -231,7 +238,7 @@ def run_corstone300( ) elf_path = os.path.join( "cmake-out", - "arm_semihosting_executor_runner_corstone-300", + f"arm_semihosting_executor_runner_{self.target_board}", "arm_executor_runner", ) assert os.path.exists( @@ -242,32 +249,66 @@ def run_corstone300( for input_path in input_paths: cmd_line += f" -i {input_path}" - command_args = [ - "FVP_Corstone_SSE-300_Ethos-U55", - "-C", - "ethosu.num_macs=128", - "-C", - "mps3_board.visualisation.disable-visualisation=1", - "-C", - "mps3_board.telnetterminal0.start_telnet=0", - "-C", - "mps3_board.uart0.out_file='-'", - "-C", - "cpu0.CFGITCMSZ=11", - "-C", - "cpu0.semihosting-enable=1", - "-C", - "cpu0.semihosting-stack_base=0", - "-C", - "cpu0.semihosting-heap_limit=0", - "-C", - f"cpu0.semihosting-cmd_line='{cmd_line}'", - "-a", - elf_path, - "--timelimit", - f"{self.timeout}", - ] - result = _run_cmd(command_args, check=False) + command_args = { + "corstone-300": [ + "FVP_Corstone_SSE-300_Ethos-U55", + "-C", + "ethosu.num_macs=128", + "-C", + "mps3_board.visualisation.disable-visualisation=1", + "-C", + "mps3_board.telnetterminal0.start_telnet=0", + "-C", + "mps3_board.uart0.out_file='-'", + "-C", + "cpu0.CFGITCMSZ=11", + "-C", + "cpu0.semihosting-enable=1", + "-C", + "cpu0.semihosting-stack_base=0", + "-C", + "cpu0.semihosting-heap_limit=0", + "-C", + f"cpu0.semihosting-cmd_line='{cmd_line}'", + "-a", + elf_path, + "--timelimit", + f"{self.timeout}", + ], + "corstone-320": [ + "FVP_Corstone_SSE-320", + "-C", + "mps4_board.subsystem.ethosu.num_macs=128", + "-C", + "mps4_board.visualisation.disable-visualisation=1", + "-C", + "mps4_board.telnetterminal0.start_telnet=0", + "-C", + "mps4_board.uart0.out_file='-'", + "-C", + "mps4_board.uart0.unbuffered_output=1", + "-C", + "mps4_board.uart0.shutdown_on_eot=1", + "-C", + "mps4_board.subsystem.cpu0.semihosting-enable=1", + "-C", + "mps4_board.subsystem.cpu0.semihosting-stack_base=0", + "-C", + "mps4_board.subsystem.cpu0.semihosting-heap_limit=0", + "-C", + f"mps4_board.subsystem.cpu0.semihosting-cmd_line='{cmd_line}'", + "-a", + elf_path, + "--timelimit", + f"{self.timeout}", + ], + } + + result = _run_cmd(command_args[self.target_board], check=False) + if result.returncode != 0: + raise RuntimeError( + f"Failed to run {command_args[self.target_board]}\nError: {result.stderr.decode()}" + ) result_stdout = result.stdout.decode() error_regex = r"(^[EF][: ].*$)|(^.*Hard fault.*$)|(^.*Assertion.*$)" @@ -276,10 +317,8 @@ def run_corstone300( # regex to check for error or fault messages in stdout from FVP if re.compile(error_regex, re.MULTILINE).search(result_stdout): raise RuntimeError( - f"Corstone simulation failed, log: \n {result_stdout}\n{result.stderr.decode()}" + f"Corstone simulation failed:\ncmd: {command_args[self.target_board]}\n, log: \n 
{result_stdout}\n{result.stderr.decode()}" ) - elif "E [" in result_stdout: - logger.error(result_stdout) tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) output_shape = self.output_node.args[0][0].meta["val"].shape diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 7e8a1198ad0..eb52f4b2070 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -98,7 +98,7 @@ def __init__(self, runner_util: RunnerUtil, timeout: int = 1): self.runner.set_timeout(timeout) def run_artifact(self, inputs): - return self.runner.run_corstone300(inputs) + return self.runner.run_corstone(inputs) def dump_artifact(self, path_to_dump: Optional[str]): if not path_to_dump: @@ -226,6 +226,7 @@ def run_method_and_compare_outputs( self, inputs: Optional[Tuple[torch.Tensor]] = None, stage: Optional[str] = None, + target_board: Optional[str] = "corstone-300", num_runs=1, atol=1e-03, rtol=1e-03, @@ -260,7 +261,12 @@ def run_method_and_compare_outputs( edge_program = self.stages[ self.stage_name(tester.ToEdge) ].artifact.exported_program() - self.runner_util.init_run(exported_program, edge_program, is_quantized) + self.runner_util.init_run( + exported_program, + edge_program, + is_quantized, + target_board, + ) if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 0539d4f5e4b..3691cd0234d 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -967,6 +967,7 @@ class Inspector: def __init__( self, etdump_path: Optional[str] = None, + etdump_data: Optional[bytes] = None, etrecord: Optional[Union[ETRecord, str]] = None, source_time_scale: TimeScale = TimeScale.NS, target_time_scale: TimeScale = TimeScale.MS, @@ -980,11 +981,12 @@ def __init__( enable_module_hierarchy: bool = False, ) -> None: r""" - Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path + Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path or binary, and optional ETRecord path. Args: - etdump_path: Path to the ETDump file. + etdump_path: Path to the ETDump file. Either this parameter or etdump_data should be provided. + etdump_data: ETDump binary. Either this parameter or etdump_path should be provided. etrecord: Optional ETRecord object or path to the ETRecord file. source_time_scale: The time scale of the performance data retrieved from the runtime. The default time hook implentation in the runtime returns NS. target_time_scale: The target time scale to which the users want their performance data converted to. Defaults to MS. @@ -1025,8 +1027,13 @@ def __init__( else: raise TypeError("Unsupported ETRecord type") + if (etdump_path is None) == (etdump_data is None): + raise ValueError( + "Expecting exactly one of etdump_path or etdump_data to be specified." 
+ ) + # Create EventBlocks from ETDump - etdump = gen_etdump_object(etdump_path=etdump_path) + etdump = gen_etdump_object(etdump_path=etdump_path, etdump_data=etdump_data) if debug_buffer_path is not None: with open(debug_buffer_path, "rb") as f: output_buffer = f.read() diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 5f04e2d0413..a2989c224e1 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -279,13 +279,20 @@ def _extract_debug_handles(graph: OperatorGraph): return debug_handle_to_op_node_map -def gen_etdump_object(etdump_path: Optional[str] = None) -> ETDumpFlatCC: +def gen_etdump_object( + etdump_path: Optional[str] = None, etdump_data: Optional[bytes] = None +) -> ETDumpFlatCC: # Gen event blocks from etdump - if etdump_path is None: - raise ValueError("Etdump_path must be specified.") - with open(etdump_path, "rb") as buff: - etdump = deserialize_from_etdump_flatcc(buff.read()) - return etdump + if etdump_data is None and etdump_path is not None: + with open(etdump_path, "rb") as buff: + etdump_data = buff.read() + + if etdump_data is None: + raise ValueError( + "Unable to get ETDump data. One and only one of etdump_path and etdump_data must be specified." + ) + + return deserialize_from_etdump_flatcc(etdump_data) def plot_metric(result: List[float], metric_name: str): diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 34c96eef534..4b3f8075d8e 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -86,7 +86,9 @@ def test_inspector_constructor(self): # Assert that expected functions are called mock_parse_etrecord.assert_called_once_with(etrecord_path=ETRECORD_PATH) - mock_gen_etdump.assert_called_once_with(etdump_path=ETDUMP_PATH) + mock_gen_etdump.assert_called_once_with( + etdump_path=ETDUMP_PATH, etdump_data=None + ) mock_gen_from_etdump.assert_called_once() # Because we mocked parse_etrecord() to return None, this method shouldn't be called mock_gen_graphs_from_etrecord.assert_not_called() diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 73d552cb268..20224b9e9c3 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -45,20 +45,28 @@ function verify_md5() { script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) if [[ "${ARCH}" == "x86_64" ]]; then - # FVP - fvp_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64.tgz?rev=018659bd574f4e7b95fa647e7836ccf4&hash=22A79103C6FA5FFA7AFF3BE0447F3FF9" - fvp_model_dir="Linux64_GCC-9.3" - fvp_md5_checksum="98e93b949d0fbac977292d8668d34523" + # FVPs + corstone300_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64.tgz?rev=018659bd574f4e7b95fa647e7836ccf4&hash=22A79103C6FA5FFA7AFF3BE0447F3FF9" + corstone300_model_dir="Linux64_GCC-9.3" + corstone300_md5_checksum="98e93b949d0fbac977292d8668d34523" + + corstone320_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-320/FVP_Corstone_SSE-320_11.27_25_Linux64.tgz?rev=a507bffc219a4d5792f1192ab7002d89&hash=D9A824AA8227D2E679C9B9787FF4E8B6FBE3D7C6" + corstone320_model_dir="Linux64_GCC-9.3" + corstone320_md5_checksum="3deb3c68f9b2d145833f15374203514d" # toochain 
toolchain_url="https://armkeil.blob.core.windows.net/developer/Files/downloads/gnu/12.3.rel1/binrel/arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi.tar.xz" toolchain_dir="arm-gnu-toolchain-12.3.rel1-x86_64-arm-none-eabi" toolchain_md5_checksum="00ebb1b70b1f88906c61206457eacb61" elif [[ "${ARCH}" == "aarch64" ]] || [[ "${ARCH}" == "arm64" ]]; then - # FVP - fvp_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64_armv8l.tgz?rev=9cc6e9a32bb947ca9b21fa162144cb01&hash=7657A4CF27D42E892E3F08D452AAB073" - fvp_model_dir="Linux64_armv8l_GCC-9.3" - fvp_md5_checksum="cbbabbe39b07939cff7a3738e1492ef1" + # FVPs + corstone300_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-300/FVP_Corstone_SSE-300_11.22_20_Linux64_armv8l.tgz?rev=9cc6e9a32bb947ca9b21fa162144cb01&hash=7657A4CF27D42E892E3F08D452AAB073" + corstone300_model_dir="Linux64_armv8l_GCC-9.3" + corstone300_md5_checksum="cbbabbe39b07939cff7a3738e1492ef1" + + corstone320_url="https://developer.arm.com/-/media/Arm%20Developer%20Community/Downloads/OSS/FVP/Corstone-320/FVP_Corstone_SSE-320_11.27_25_Linux64_armv8l.tgz?rev=b6ebe0923cb84f739e017385fd3c333c&hash=8965C4B98E2FF7F792A099B08831FE3CB6120493" + corstone320_model_dir="Linux64_armv8l_GCC-9.3" + corstone320_md5_checksum="3889f1d80a6d9861ea4aa6f1c88dd0ae" # toochain if [[ "${OS}" == "Darwin" ]]; then @@ -105,26 +113,50 @@ function setup_fvp() { fi # Download and install the Corstone 300 FVP simulator platform - cd "${root_dir}" - if [[ ! -e FVP_cs300.tgz ]]; then - echo "[${FUNCNAME[0]}] Downloading FVP ..." - curl --output FVP_cs300.tgz "${fvp_url}" - verify_md5 ${fvp_md5_checksum} FVP_cs300.tgz - fi - - echo "[${FUNCNAME[0]}] Installing FVP ..." - rm -rf FVP - mkdir -p FVP - cd FVP - tar xf ../FVP_cs300.tgz - ./FVP_Corstone_SSE-300.sh --i-agree-to-the-contained-eula --force --destination ./ --quiet --no-interactive - - fvp_bin_path="$(cd models/${fvp_model_dir} && pwd)" - export PATH=${PATH}:${fvp_bin_path} - - hash FVP_Corstone_SSE-300_Ethos-U55 - echo "export PATH=\${PATH}:${fvp_bin_path}" >> ${setup_path_script} - + fvps=("corstone300" "corstone320") + + for fvp in "${fvps[@]}"; do + cd "${root_dir}" + if [[ ! -e "FVP_${fvp}.tgz" ]]; then + echo "[${FUNCNAME[0]}] Downloading FVP ${fvp}..." + url_variable=${fvp}_url + fvp_url=${!url_variable} + curl --output "FVP_${fvp}.tgz" "${fvp_url}" + md5_variable=${fvp}_md5_checksum + fvp_md5_checksum=${!md5_variable} + verify_md5 ${fvp_md5_checksum} FVP_${fvp}.tgz + fi + + echo "[${FUNCNAME[0]}] Installing FVP ${fvp}..." + rm -rf FVP-${fvp} + mkdir -p FVP-${fvp} + cd FVP-${fvp} + tar xf ../FVP_${fvp}.tgz + + # Install the FVP + case ${fvp} in + corstone300) + ./FVP_Corstone_SSE-300.sh --i-agree-to-the-contained-eula --force --destination ./ --quiet --no-interactive + ;; + corstone320) + ./FVP_Corstone_SSE-320.sh --i-agree-to-the-contained-eula --force --destination ./ --quiet --no-interactive + ;; + *) + echo "[${FUNCNAME[0]}] Error: Unknown FVP model ${fvp}. Exiting." 
+ exit 1 + ;; + esac + + model_dir_variable=${fvp}_model_dir + fvp_model_dir=${!model_dir_variable} + fvp_bin_path="$(cd models/${fvp_model_dir} && pwd)" + export PATH=${PATH}:${fvp_bin_path} + + echo "export PATH=\${PATH}:${fvp_bin_path}" >> ${setup_path_script} + done + + # Fixup for Corstone-320 python dependency + echo "export LD_LIBRARY_PATH=${root_dir}/FVP-corstone320/python/lib/" >> ${setup_path_script} } function setup_toolchain() { diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 138044d5342..bcca1b82ba4 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -162,13 +162,13 @@ python -m examples.models.llama2.export_llama \ --params "${LLAMA_PARAMS:?}" \ --use_sdpa_with_kv_cache \ -X \ - --spin_qmode 8da4w_output_8da8w \ - --spin_group_size 32 \ + --preq_mode 8da4w_output_8da8w \ + --preq_group_size 32 \ --max_seq_length 2048 \ --output_name "llama3_2.pte" \ -kv \ -d fp32 \ - --spin_embedding_quantize 8,0 \ + --preq_embedding_quantize 8,0 \ --use_spin_quant native \ --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' ``` diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 1b641c1d0fd..40822e574c3 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -80,6 +80,7 @@ runtime.python_library( "export_llama_lib.py", "model.py", "source_transformation/apply_spin_quant_r1_r2.py", + "source_transformation/pre_quantization.py", "source_transformation/prune_output.py", "source_transformation/quantize.py", "source_transformation/quantized_kv_cache.py", diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index 2b43274760a..a39bb048200 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -391,25 +391,25 @@ def build_args_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--spin_qmode", + "--preq_mode", type=str, default=None, choices=["8da4w", "8da4w_output_8da8w"], - help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.", + help="Quantization mode used for pre-quantized checkpoint. Only support 8da4w and 8da4w_output_8da8w right now.", ) parser.add_argument( - "--spin_group_size", + "--preq_group_size", type=int, default=32, - help="group_size for SpinQuant weight quantization", + help="group_size for pre-quantized checkpoint weight quantization", ) parser.add_argument( - "--spin_embedding_quantize", + "--preq_embedding_quantize", default="8,0", type=str, - help="type of embedding quantization for SpinQuant, ',', e.g., '8,1024'.", + help="type of embedding quantization for pre-quantized checkpoint, ',', e.g., '8,1024'.", ) parser.add_argument( diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py index c48fa98d576..a4081d1bd57 100644 --- a/examples/models/llama2/model.py +++ b/examples/models/llama2/model.py @@ -191,20 +191,20 @@ def __init__(self, **kwargs): ) elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant: print("Using SPIN quantization.") - assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified" - assert self.args.spin_qmode in [ + assert hasattr(self.args, "preq_mode"), "preq_mode must be specified" + assert self.args.preq_mode in [ "8da4w", "8da4w_output_8da8w", - ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant." 
+ ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant." assert hasattr( - self.args, "spin_group_size" - ), "spin_group_size must be specified" + self.args, "preq_group_size" + ), "preq_group_size must be specified" assert hasattr( self.args, "dtype_override" ), "dtype_override must be specified" - from .source_transformation.spin_quant import ( - sanitize_checkpoint_from_spinquant, - transform_linear_for_spinquant, + from .source_transformation.pre_quantization import ( + sanitize_checkpoint_from_pre_quantization, + transform_linear_for_pre_quantization, ) mapping = { @@ -214,31 +214,31 @@ def __init__(self, **kwargs): } # Transform the output layer first if needed. - if self.args.spin_qmode == "8da4w_output_8da8w": - from .source_transformation.spin_quant import ( - transform_output_linear_for_spinquant, + if self.args.preq_mode == "8da4w_output_8da8w": + from .source_transformation.pre_quantization import ( + transform_output_linear_for_pre_quantization, ) - self.model_ = transform_output_linear_for_spinquant( + self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, dtype=mapping[self.args.dtype_override], ) - self.model_ = transform_linear_for_spinquant( + self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, - self.args.spin_group_size, + self.args.preq_group_size, mapping[self.args.dtype_override], ) embedding_bit_width, embedding_group_size = None, None - if hasattr(self.args, "spin_embedding_quantize"): + if hasattr(self.args, "preq_embedding_quantize"): embedding_bit_width, embedding_group_size = ( - self.args.spin_embedding_quantize.split(",") + self.args.preq_embedding_quantize.split(",") ) - from .source_transformation.spin_quant import ( - transform_embedding_for_spinquant, + from .source_transformation.pre_quantization import ( + transform_embedding_for_pre_quantization, ) if ( @@ -250,7 +250,7 @@ def __init__(self, **kwargs): else: embedding_group_size = int(embedding_group_size) - self.model_ = transform_embedding_for_spinquant( + self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, mapping[self.args.dtype_override], @@ -258,7 +258,7 @@ def __init__(self, **kwargs): embedding_group_size, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) # assign=True: load params/buffers by assignment instead of performing an in-place copy. # Because we are using device="meta", tensors do not have memory associated with them diff --git a/examples/models/llama2/source_transformation/pre_quantization.py b/examples/models/llama2/source_transformation/pre_quantization.py new file mode 100644 index 00000000000..38937c5ab4e --- /dev/null +++ b/examples/models/llama2/source_transformation/pre_quantization.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Helper functions for tranforming the model to be able to load pre-quantized checkpoints. 
+ +from typing import Any, Optional + +import torch +from torch import nn + +from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear +from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + +from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding + + +def _replace_linear_with_linear_8da4w_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + group_size: int, + precision: torch.dtype, + scales_precision: torch.dtype, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + # Only replace linear layers where the checkpoint contains explicit scales + scales_key = f"{cur_fqn}.scales" + if isinstance(child, nn.Linear) and scales_key in checkpoint: + assert _check_linear_int4_k(child.in_features, group_size) + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == scales_precision + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_linear = Int8DynActInt4WeightLinear( + child.in_features, + child.out_features, + bias=False, + device=child.weight.device, + groupsize=group_size, + precision=precision, + scales_precision=scales_precision, + ) + return new_linear + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_linear_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + group_size: int, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + are quantized with the given group size and quantization mode for + linear layers. + """ + + if group_size not in [32, 64, 128, 256]: + raise ValueError( + f"Group size {group_size} is not supported for pre-quantized checkpoint." + ) + _replace_linear_with_linear_8da4w_for_pre_quantization( + module, + checkpoint, + group_size, + dtype, + dtype, + ) + return module + + +def _replace_output_linear_with_linear_int8_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + scales_key = f"{cur_fqn}.scales" + if ( + isinstance(child, nn.Linear) + and scales_key in checkpoint + and "output" in cur_fqn + ): + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == dtype + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_linear = Int8DynActInt8WeightLinear( + device=child.weight.device, + in_features=child.in_features, + out_features=child.out_features, + precision=dtype, + bias=False, + ) + return new_linear + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_output_linear_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + has the output layer quantized per-channel. 
+ """ + _replace_output_linear_with_linear_int8_for_pre_quantization( + module, + checkpoint, + dtype, + ) + return module + + +def _replace_embedding_with_quantized_group_embedding_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, + bit_width: int, + group_size: Optional[int] = None, +): + def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: + # Only replace embedding layers where the checkpoint contains explicit scales + scales_key = f"{cur_fqn}.scales" + if isinstance(child, nn.Embedding) and scales_key in checkpoint: + assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 + assert checkpoint[scales_key].dtype == torch.float32 + return True + return False + + def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: + new_embedding = QuantizedGroupEmbedding( + device=child.weight.device, + vocab_size=child.weight.shape[0], + embedding_dim=child.weight.shape[1], + group_size=group_size, + dtype=dtype, + packed=False, # TODO(lunwenh): support packed embedding for pre-quantized + ) + return new_embedding + + _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) + + +def transform_embedding_for_pre_quantization( + module: torch.nn.Module, + checkpoint: Any, + dtype: torch.dtype, + bit_width: int, + group_size: Optional[int] = None, +) -> torch.nn.Module: + """ + Transform the model to be able to load pre-quantized checkpoints that + are quantized with the given bit_width and group size for embedding. + """ + if group_size is not None and group_size not in [0, 32, 64, 128, 256]: + raise ValueError( + f"Group size {group_size} is not supported for pre-quantized checkpoint." + ) + _replace_embedding_with_quantized_group_embedding_for_pre_quantization( + module, + checkpoint, + dtype, + bit_width, + group_size, + ) + return module + + +def sanitize_checkpoint_from_pre_quantization( + checkpoint: Any, +): + """ + Sanitize the pre-quantized checkpoint. + - Converts all tensors to contiguous format + - Squeeze all tensors + """ + for k, v in checkpoint.items(): + checkpoint[k] = torch.squeeze(v.contiguous()) diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py index f579e1352eb..f544e9e1f6e 100644 --- a/examples/models/llama2/source_transformation/spin_quant.py +++ b/examples/models/llama2/source_transformation/spin_quant.py @@ -9,7 +9,6 @@ # Helper functions for tranforming the model to be able to run SpinQuant. # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant. 
-from typing import Any, Optional import torch @@ -17,10 +16,6 @@ from executorch.examples.models.llama2.llama_transformer import FeedForward from torch import nn -from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear -from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - -from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module): @@ -91,171 +86,3 @@ def inject_fast_hadamard_transform_native_for_spin_quant( ) -> torch.nn.Module: _inject_fast_hadamard_transform_native_for_spin_quant(module) return module - - -def _replace_linear_with_linear_8da4w_for_spin_quant( - module: torch.nn.Module, - checkpoint: Any, - group_size: int, - precision: torch.dtype, - scales_precision: torch.dtype, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - # Only replace linear layers where the checkpoint contains explicit scales - scales_key = f"{cur_fqn}.scales" - if isinstance(child, nn.Linear) and scales_key in checkpoint: - assert _check_linear_int4_k(child.in_features, group_size) - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == scales_precision - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_linear = Int8DynActInt4WeightLinear( - child.in_features, - child.out_features, - bias=False, - device=child.weight.device, - groupsize=group_size, - precision=precision, - scales_precision=scales_precision, - ) - return new_linear - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_linear_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - group_size: int, - dtype: torch.dtype, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - are quantized with the given group size and quantization mode for - linear layers. - """ - - if group_size not in [32, 64, 128, 256]: - raise ValueError(f"Group size {group_size} is not supported for SpinQuant.") - _replace_linear_with_linear_8da4w_for_spin_quant( - module, - checkpoint, - group_size, - dtype, - dtype, - ) - return module - - -def _replace_output_linear_with_linear_int8_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - scales_key = f"{cur_fqn}.scales" - if ( - isinstance(child, nn.Linear) - and scales_key in checkpoint - and "output" in cur_fqn - ): - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == dtype - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_linear = Int8DynActInt8WeightLinear( - device=child.weight.device, - in_features=child.in_features, - out_features=child.out_features, - precision=dtype, - bias=False, - ) - return new_linear - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_output_linear_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - has the output layer quantized per-channel. 
- """ - _replace_output_linear_with_linear_int8_for_spinquant( - module, - checkpoint, - dtype, - ) - return module - - -def _replace_embedding_with_quantized_group_embedding_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, - bit_width: int, - group_size: Optional[int] = None, -): - def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool: - # Only replace embedding layers where the checkpoint contains explicit scales - scales_key = f"{cur_fqn}.scales" - if isinstance(child, nn.Embedding) and scales_key in checkpoint: - assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8 - assert checkpoint[scales_key].dtype == torch.float32 - return True - return False - - def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: - new_embedding = QuantizedGroupEmbedding( - device=child.weight.device, - vocab_size=child.weight.shape[0], - embedding_dim=child.weight.shape[1], - group_size=group_size, - dtype=dtype, - packed=False, # TODO(lunwenh): support packed embedding for SpinQuant - ) - return new_embedding - - _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn) - - -def transform_embedding_for_spinquant( - module: torch.nn.Module, - checkpoint: Any, - dtype: torch.dtype, - bit_width: int, - group_size: Optional[int] = None, -) -> torch.nn.Module: - """ - Transform the model to be able to load SpinQuant checkpoints that - are quantized with the given bit_width and group size for embedding. - """ - if group_size is not None and group_size not in [0, 32, 64, 128, 256]: - raise ValueError(f"Group size {group_size} is not supported for SpinQuant.") - _replace_embedding_with_quantized_group_embedding_for_spinquant( - module, - checkpoint, - dtype, - bit_width, - group_size, - ) - return module - - -def sanitize_checkpoint_from_spinquant( - checkpoint: Any, -): - """ - Sanitize the SpinQuant checkpoint. 
- - Converts all tensors to contiguous format - - Squeeze all tensors - """ - for k, v in checkpoint.items(): - checkpoint[k] = torch.squeeze(v.contiguous()) diff --git a/examples/models/llama2/tests/TARGETS b/examples/models/llama2/tests/TARGETS index 76981d8f317..2e4dcf7d1f6 100644 --- a/examples/models/llama2/tests/TARGETS +++ b/examples/models/llama2/tests/TARGETS @@ -15,9 +15,9 @@ python_unittest( ) python_unittest( - name = "test_spinquant_transforms", + name = "test_pre_quantization_transforms", srcs = [ - "test_spinquant_transforms.py", + "test_pre_quantization_transforms.py", ], deps = [ "//caffe2:torch", diff --git a/examples/models/llama2/tests/test_spinquant_transforms.py b/examples/models/llama2/tests/test_pre_quantization_transforms.py similarity index 86% rename from examples/models/llama2/tests/test_spinquant_transforms.py rename to examples/models/llama2/tests/test_pre_quantization_transforms.py index 4f6306814b6..59cec2e72ab 100644 --- a/examples/models/llama2/tests/test_spinquant_transforms.py +++ b/examples/models/llama2/tests/test_pre_quantization_transforms.py @@ -8,19 +8,19 @@ import torch from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer +from executorch.examples.models.llama2.source_transformation.pre_quantization import ( + sanitize_checkpoint_from_pre_quantization, + transform_embedding_for_pre_quantization, + transform_linear_for_pre_quantization, + transform_output_linear_for_pre_quantization, +) from executorch.examples.models.llama2.source_transformation.quantize import ( dynamically_quantize_per_channel, ) -from executorch.examples.models.llama2.source_transformation.spin_quant import ( - sanitize_checkpoint_from_spinquant, - transform_embedding_for_spinquant, - transform_linear_for_spinquant, - transform_output_linear_for_spinquant, -) from torchao.quantization.utils import group_quantize_tensor_symmetric -class SpinQuantTests(unittest.TestCase): +class PreQuantizationTests(unittest.TestCase): def _prepare_dummy_model(self) -> Transformer: model_args = ModelArgs( @@ -42,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer: return model - def test_transform_linear_for_spinquant(self): + def test_transform_linear_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() @@ -69,14 +69,13 @@ def test_transform_linear_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_linear_for_spinquant( + transform_linear_for_pre_quantization( model, checkpoint, 32, - "8da4w", torch.float32, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, @@ -91,7 +90,7 @@ def test_transform_linear_for_spinquant(self): # have to iterate over the keys. 
self.assertTrue(torch.allclose(new_checkpoint[k], v)) - def test_transform_output_linear_for_spinquant(self): + def test_transform_output_linear_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() checkpoint = model.state_dict() @@ -114,12 +113,12 @@ def test_transform_output_linear_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_output_linear_for_spinquant( + transform_output_linear_for_pre_quantization( model, checkpoint, torch.float32, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, @@ -134,7 +133,7 @@ def test_transform_output_linear_for_spinquant(self): # have to iterate over the keys. self.assertTrue(torch.allclose(new_checkpoint[k], v)) - def test_transform_embedding_for_spinquant(self): + def test_transform_embedding_for_pre_quantization(self): # Step 1: Create llama class with dummy weights model = self._prepare_dummy_model() @@ -162,14 +161,14 @@ def test_transform_embedding_for_spinquant(self): # Step 3: # Transform the model so that it is compatible with the new checkpoint - transform_embedding_for_spinquant( + transform_embedding_for_pre_quantization( model, checkpoint, torch.float32, n_bit, group_size, ) - sanitize_checkpoint_from_spinquant(checkpoint) + sanitize_checkpoint_from_pre_quantization(checkpoint) model.load_state_dict( checkpoint, diff --git a/install_requirements.py b/install_requirements.py index f169c52257c..2cdb4184ead 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -94,14 +94,14 @@ def python_is_compatible(): # NOTE: If a newly-fetched version of the executorch repo changes the value of # NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION = "dev20240912" +NIGHTLY_VERSION = "dev20240925" # The pip repository that hosts nightly torch packages. TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu" # pip packages needed by exir. EXIR_REQUIREMENTS = [ - f"torch==2.5.0.{NIGHTLY_VERSION}", + f"torch==2.6.0.{NIGHTLY_VERSION}", f"torchvision==0.20.0.{NIGHTLY_VERSION}", # For testing. "typing-extensions", ] diff --git a/pytest.ini b/pytest.ini index ec4b381ac59..ecd58ea07e4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -43,7 +43,7 @@ addopts = --ignore=backends/xnnpack/test/ops/linear.py --ignore=backends/xnnpack/test/models/llama2_et_example.py # T200992559: Add torchao to ET as core dependency - --ignore=examples/models/llama2/tests/test_spinquant_transforms.py + --ignore=examples/models/llama2/tests/test_pre_quantization_transforms.py --ignore=exir/backend/test/demos --ignore=exir/backend/test/test_backends.py --ignore=exir/backend/test/test_backends_lifted.py
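
Usage note on the `Inspector` change above: the constructor now accepts an in-memory ETDump buffer through the new `etdump_data` argument, and exactly one of `etdump_path` or `etdump_data` must be supplied. A minimal sketch of the new call pattern follows — the `executorch.devtools` import path and the dump file location are assumptions for illustration, not part of this patch:

```python
# Sketch only: exercises the new etdump_data parameter added in
# devtools/inspector/_inspector.py. Paths below are hypothetical.
from executorch.devtools import Inspector

# Read an ETDump that is already available as raw bytes
# (e.g. returned from a runtime harness instead of written to disk).
with open("/tmp/etdump.bin", "rb") as f:
    etdump_bytes = f.read()

# Construct the Inspector from the bytes instead of a file path.
inspector = Inspector(etdump_data=etdump_bytes)

# Supplying both etdump_path and etdump_data, or neither, now raises:
#   ValueError: Expecting exactly one of etdump_path or etdump_data to be specified.
```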