diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index ff0277c2056..dcd2afa7a13 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -28,3 +28,6 @@ matplotlib>=3.9.4 myst-parser==0.18.1 sphinx_design==0.4.1 sphinx-copybutton==0.5.0 + +# script unit test requirements +yaspin==3.1.0 diff --git a/.ci/scripts/benchmark_tooling/README.md b/.ci/scripts/benchmark_tooling/README.md new file mode 100644 index 00000000000..8767bdda6f9 --- /dev/null +++ b/.ci/scripts/benchmark_tooling/README.md @@ -0,0 +1,172 @@ +# Executorch Benchmark Tooling + +A library providing tools for fetching, processing, and analyzing ExecutorchBenchmark data from the HUD Open API. This tooling helps compare performance metrics between private and public devices with identical settings. + +## Table of Contents + +- [Overview](#overview) +- [Installation](#installation) +- [Tools](#tools) + - [get_benchmark_analysis_data.py](#get_benchmark_analysis_datapy) + - [Quick Start](#quick-start) + - [Command Line Options](#command-line-options) + - [Example Usage](#example-usage) + - [Working with Output Files](#working-with-output-files-csv-and-excel) + - [Python API Usage](#python-api-usage) +- [Running Unit Tests](#running-unit-tests) + +## Overview + +The Executorch Benchmark Tooling provides a suite of utilities designed to: + +- Fetch benchmark data from HUD Open API for specified time ranges +- Clean and process data by filtering out failures +- Compare metrics between private and public devices with matching configurations +- Generate analysis reports in various formats (CSV, Excel, JSON) +- Support filtering by device pools, backends, and models + +This tooling is particularly useful for performance analysis, regression testing, and cross-device comparisons. + +## Installation + +Install dependencies: + +```bash +pip install -r requirements.txt +``` + +## Tools + +### get_benchmark_analysis_data.py + +This script is mainly used to generate analysis data comparing private devices with public devices using the same settings. + +It fetches benchmark data from HUD Open API for a specified time range, cleans the data by removing entries with FAILURE indicators, and retrieves all private device metrics along with equivalent public device metrics based on matching [model, backend, device_pool_names, arch] configurations. Users can filter the data by specifying private device_pool_names, backends, and models. + +#### Quick Start + +```bash +# generate excel sheets for all private devices with public devices using the same settings +python3 .ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py \ + --startTime "2025-06-11T00:00:00" \ + --endTime "2025-06-17T18:00:00" \ + --outputType "excel" + +# generate the benchmark stability analysis +python3 .ci/scripts/benchmark_tooling/analyze_benchmark_stability.py \ +--primary-file private.xlsx \ +--reference-file public.xlsx +``` + +#### Command Line Options + +##### Basic Options: +- `--startTime`: Start time in ISO format (e.g., "2025-06-11T00:00:00") (required) +- `--endTime`: End time in ISO format (e.g., "2025-06-17T18:00:00") (required) +- `--env`: Choose environment ("local" or "prod", default: "prod") +- `--no-silent`: Show processing logs (default: only show results & minimum logging) + +##### Output Options: +- `--outputType`: Choose output format (default: "print") + - `print`: Display results in console + - `json`: Generate JSON file + - `df`: Display results in DataFrame format: `{'private': List[{'groupInfo':Dict,'df': DF},...],'public':List[{'groupInfo':Dict,'df': DF}]` + - `excel`: Generate Excel files with multiple sheets, the field in first row and first column contains the JSON string of the raw metadata + - `csv`: Generate CSV files in separate folders, the field in first row and first column contains the JSON string of the raw metadata +- `--outputDir`: Directory to save output files (default: current directory) + +##### Filtering Options: + +- `--device-pools`: Filter by private device pool names (e.g., "samsung-galaxy-s22-5g", "samsung-galaxy-s22plus-5g") +- `--backends`: Filter by specific backend names (e.g.,"xnnpack_q8") +- `--models`: Filter by specific model names (e.g., "mv3", "meta-llama-llama-3.2-1b-instruct-qlora-int4-eo8") + +#### Example Usage + +Filter by multiple private device pools and models: +```bash +# This fetches all private table data for models 'llama-3.2-1B' and 'mv3' +python3 get_benchmark_analysis_data.py \ + --startTime "2025-06-01T00:00:00" \ + --endTime "2025-06-11T00:00:00" \ + --device-pools 'apple_iphone_15_private' 'samsung_s22_private' \ + --models 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3' +``` + +Filter by specific device pool and models: +```bash +# This fetches all private iPhone table data for models 'llama-3.2-1B' and 'mv3', +# and associated public iPhone data +python3 get_benchmark_analysis_data.py \ + --startTime "2025-06-01T00:00:00" \ + --endTime "2025-06-11T00:00:00" \ + --device-pools 'apple_iphone_15_private' \ + --models 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3' +``` + +#### Working with Output Files CSV and Excel + +You can use methods in `common.py` to convert the file data back to DataFrame format. These methods read the first row in CSV/Excel files and return results with the format `list of {"groupInfo":DICT, "df":df.Dataframe{}}`. + +```python +import logging +logging.basicConfig(level=logging.INFO) +from .ci.scripts.benchmark_tooling.common import read_all_csv_with_metadata, read_excel_with_json_header + +# For CSV files (assuming the 'private' folder is in the current directory) +folder_path = './private' +res = read_all_csv_with_metadata(folder_path) +logging.info(res) + +# For Excel files (assuming the Excel file is in the current directory) +file_path = "./private.xlsx" +res = read_excel_with_json_header(file_path) +logging.info(res) +``` + +#### Python API Usage + +To use the benchmark fetcher in your own scripts: + +```python +from .ci.scripts.benchmark_tooling.get_benchmark_analysis_data import ExecutorchBenchmarkFetcher + +# Initialize the fetcher +fetcher = ExecutorchBenchmarkFetcher(env="prod", disable_logging=False) + +# Fetch data for a specific time range +fetcher.run( + start_time="2025-06-11T00:00:00", + end_time="2025-06-17T18:00:00" +) + +# Get results in different formats +# As DataFrames +df_results = fetcher.to_df() + +# Export to Excel +fetcher.to_excel(output_dir="./results") + +# Export to CSV +fetcher.to_csv(output_dir="./results") + +# Export to JSON +json_path = fetcher.to_json(output_dir="./results") + +# Get raw dictionary results +dict_results = fetcher.to_dict() + +# Use the output_data method for flexible output +results = fetcher.output_data(output_type="excel", output_dir="./results") +``` + +## Running Unit Tests + +The benchmark tooling includes unit tests to ensure functionality. + +### Using pytest for unit tests + +```bash +# From the executorch root directory +pytest -c /dev/null .ci/scripts/tests/test_get_benchmark_analysis_data.py +``` diff --git a/.ci/scripts/benchmark_tooling/__init__.py b/.ci/scripts/benchmark_tooling/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/.ci/scripts/analyze_benchmark_stability.py b/.ci/scripts/benchmark_tooling/analyze_benchmark_stability.py similarity index 88% rename from .ci/scripts/analyze_benchmark_stability.py rename to .ci/scripts/benchmark_tooling/analyze_benchmark_stability.py index 47f984b7ce3..64e4b05df86 100644 --- a/.ci/scripts/analyze_benchmark_stability.py +++ b/.ci/scripts/benchmark_tooling/analyze_benchmark_stability.py @@ -1,10 +1,10 @@ import argparse import os -import re import matplotlib.pyplot as plt import numpy as np import pandas as pd +from common import read_excel_with_json_header from tabulate import tabulate @@ -15,40 +15,23 @@ def print_section_header(title): print("=" * 100 + "\n") -def normalize_tab_name(name): +def normalize_name(name): """Normalize tab name for better matching""" # Convert to lowercase and remove spaces - return name.lower().replace(" ", "") + return name.lower().replace(" ", "").replace("(private)", "") -def parse_model_device(sheet_name): - """Extract model and device from sheet name using the 'model+device' pattern""" - parts = sheet_name.split("+", 1) - if len(parts) < 2: - return sheet_name, "Unknown" - return parts[0], parts[1] - - -def extract_model_device_os(sheet_name): - """ - Extract model, device, and OS from sheet name - Format expected: model+device_osname - Returns: (model, device_base, os_version) - """ - model, device_full = parse_model_device(sheet_name) - - # Use regex to separate device base name from OS version - # Pattern looks for device name followed by underscore or android/ios - match = re.match(r"(.*?)(android|ios|_)(.*)", device_full, re.IGNORECASE) - - if match: - device_base = match.group(1).rstrip("_") - os_name = match.group(2) - os_version = match.group(3) - return model, device_base, f"{os_name}{os_version}" - else: - # If no OS version found, return the device as is with empty OS - return model, device_full, "" +def parse_model_device_config(config): + """Extract model and device from config""" + model = config.get("model", "") + backend = config.get("backend", "") + full_model = f"{model}({backend})" if backend else model + base_device = config.get("device", "") + os_version = config.get("arch", "") + full_device = f"{base_device}({os_version})" if os_version else base_device + if not base_device: + return full_model, "unkown", "unknown", "" + return full_model, full_device, base_device, os_version def is_matching_dataset(primary_sheet, reference_sheet): @@ -56,10 +39,20 @@ def is_matching_dataset(primary_sheet, reference_sheet): Check if two datasets match for comparison based on model and device Allows different OS versions for the same device """ - primary_model, primary_device, primary_os = extract_model_device_os(primary_sheet) - reference_model, reference_device, reference_os = extract_model_device_os( - reference_sheet - ) + primary_model = normalize_name(primary_sheet.get("model", "")) + primary_device = normalize_name(primary_sheet.get("base_device", "")) + # primary_os = normalize_name(primary_sheet.get("os_version", "")) + + reference_model = normalize_name(reference_sheet.get("model", "")) + reference_device = normalize_name(reference_sheet.get("base_device", "")) + # reference_os = normalize_name(reference_sheet.get("os_version", "")) + + if not primary_model: + print("Warning: Primary sheet {} has no model info, for {primary_model} ") + return False + if not reference_model: + print("Warning: Reference sheet {} has no model info, for {reference_model}") + return False # Model must match exactly if primary_model != reference_model: @@ -69,26 +62,12 @@ def is_matching_dataset(primary_sheet, reference_sheet): if primary_device != reference_device: return False - # If we get here, model and device base match, so it's a valid comparison - # even if OS versions differ return True def analyze_latency_stability( # noqa: C901 primary_file, reference_file=None, output_dir="stability_analysis_results" ): - """ - Analyze latency stability metrics from benchmark data in Excel files. - - Parameters: - ----------- - primary_file : str - Path to the Excel file containing primary (private) benchmark data - reference_file : str, optional - Path to the Excel file containing reference (public) benchmark data - output_dir : str - Directory to save output files - """ print(f"Analyzing latency stability from primary file: {primary_file}") if reference_file: print(f"Using reference file for comparison: {reference_file}") @@ -100,15 +79,28 @@ def analyze_latency_stability( # noqa: C901 # Load primary datasets print_section_header("LOADING PRIMARY DATASETS (Private)") primary_datasets = {} - primary_xls = pd.ExcelFile(primary_file) + documents = read_excel_with_json_header(primary_file) + + for document in documents: + sheetName = document.get("sheetName", None) + df = document.get("df", None) + config = document.get("groupInfo", None) + print(f"Loading dataset: {sheetName} with config: {config} ") + + if df is None or df.empty: + print(f"Skipping sheet {sheetName} because it has no df data") + continue + + if not config or not sheetName: + print( + f" Skipping document: Missing required data groupInfo:{config} sheetName:{sheetName}" + ) + continue - for sheet in primary_xls.sheet_names: - print(f"Loading dataset: {sheet}") - df = pd.read_excel(primary_xls, sheet_name=sheet) - model, device = parse_model_device(sheet) + model, full_device, base_device, os_version = parse_model_device_config(config) # Check if required columns exist - required_cols = ["InferenceTime", "Date"] + required_cols = ["avg_inference_latency(ms)", "metadata_info.timestamp"] if "trimmean_inference_latency(ms)" in df.columns: trimmed_col = "trimmean_inference_latency(ms)" required_cols.append(trimmed_col) @@ -123,36 +115,54 @@ def analyze_latency_stability( # noqa: C901 # Skip sheets without required columns if not all(col in df.columns for col in required_cols): - print(f" Skipping {sheet}: Missing required columns") + print(f" Skipping {sheetName}: Missing required columns") continue # Convert Date to datetime - df["Date"] = pd.to_datetime(df["Date"]) + df["Date"] = pd.to_datetime(df["metadata_info.timestamp"]) # Calculate stability metrics - metrics = calculate_stability_metrics(df, "InferenceTime", trimmed_col, tps_col) + metrics = calculate_stability_metrics( + df, "avg_inference_latency(ms)", trimmed_col, tps_col + ) - primary_datasets[sheet] = { + primary_datasets[sheetName] = { "df": df, "metrics": metrics, "model": model, - "device": device, - "sheet_name": sheet, + "full_device": full_device, + "base_device": base_device, + "os_version": os_version, + "sheet_name": sheetName, } # Load reference datasets if provided reference_datasets = {} if reference_file: print_section_header("LOADING REFERENCE DATASETS (Public)") - reference_xls = pd.ExcelFile(reference_file) + documents = read_excel_with_json_header(reference_file) + + for document in documents: + sheetName = document.get("sheetName", None) + df = document.get("df", None) + config = document.get("groupInfo", None) + print(f"Loading dataset: {sheetName} with config:{config}") + if df is None or df.empty: + print(f"Skipping sheet {sheetName} because it has no df data") + continue + + if not config or not sheetName: + print( + f" Skipping document: Missing required data groupInfo:{config} sheetName:{sheetName}" + ) + continue - for sheet in reference_xls.sheet_names: - print(f"Loading reference dataset: {sheet}") - df = pd.read_excel(reference_xls, sheet_name=sheet) - model, device = parse_model_device(sheet) + model, full_device, base_device, os_version = parse_model_device_config( + config + ) # Check if required columns exist - required_cols = ["InferenceTime", "Date"] + required_cols = ["avg_inference_latency(ms)", "metadata_info.timestamp"] if "trimmean_inference_latency(ms)" in df.columns: trimmed_col = "trimmean_inference_latency(ms)" required_cols.append(trimmed_col) @@ -167,23 +177,27 @@ def analyze_latency_stability( # noqa: C901 # Skip sheets without required columns if not all(col in df.columns for col in required_cols): - print(f" Skipping reference {sheet}: Missing required columns") + print( + f" Skipping reference {sheetName}: Missing required columns{required_cols}" + ) continue # Convert Date to datetime - df["Date"] = pd.to_datetime(df["Date"]) + df["Date"] = pd.to_datetime(df["metadata_info.timestamp"]) # Calculate stability metrics metrics = calculate_stability_metrics( - df, "InferenceTime", trimmed_col, tps_col + df, "avg_inference_latency(ms)", trimmed_col, tps_col ) - reference_datasets[sheet] = { + reference_datasets[sheetName] = { "df": df, "metrics": metrics, "model": model, - "device": device, - "sheet_name": sheet, + "full_device": full_device, + "sheet_name": sheetName, + "base_device": base_device, + "os_version": os_version, } # Process primary datasets @@ -193,7 +207,7 @@ def analyze_latency_stability( # noqa: C901 generate_dataset_report( sheet, info["model"], - info["device"], + info["full_device"], "Primary", info["df"], info["metrics"], @@ -212,7 +226,7 @@ def analyze_latency_stability( # noqa: C901 generate_dataset_report( sheet, info["model"], - info["device"], + info["full_device"], "Reference", info["df"], info["metrics"], @@ -232,7 +246,7 @@ def analyze_latency_stability( # noqa: C901 found_match = False for ref_sheet, ref_info in reference_datasets.items(): - if is_matching_dataset(primary_sheet, ref_sheet): + if is_matching_dataset(primary_info, ref_info): # Found a match print( f"Matched: {primary_sheet} (Private) with {ref_sheet} (Public)" @@ -240,11 +254,8 @@ def analyze_latency_stability( # noqa: C901 generate_comparison_report( primary_sheet, ref_sheet, - primary_info["model"], - primary_info["device"], - ref_info["device"], - primary_info["metrics"], - ref_info["metrics"], + primary_info, + ref_info, output_dir, ) found_match = True @@ -252,7 +263,9 @@ def analyze_latency_stability( # noqa: C901 break if not found_match: - print(f"Warning: No matching reference dataset for {primary_sheet}") + print( + f"Warning: No matching reference dataset for {primary_sheet} with config: {primary_info['model']}{primary_info['full_device']} " + ) if not matches_found: print("No matching datasets found between primary and reference files.") @@ -620,7 +633,12 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): df_sorted = df.sort_values("Date") # Plot raw latency - plt.plot(df_sorted["Date"], df_sorted["InferenceTime"], "b-", label="Raw Latency") + plt.plot( + df_sorted["Date"], + df_sorted["avg_inference_latency(ms)"], + "b-", + label="Raw Latency", + ) # Plot trimmed latency if available if "trimmean_inference_latency(ms)" in df_sorted.columns: @@ -634,7 +652,9 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): # Add rolling mean window = min(5, len(df_sorted)) if window > 1: - rolling_mean = df_sorted["InferenceTime"].rolling(window=window).mean() + rolling_mean = ( + df_sorted["avg_inference_latency(ms)"].rolling(window=window).mean() + ) plt.plot( df_sorted["Date"], rolling_mean, "r--", label=f"{window}-point Rolling Mean" ) @@ -658,11 +678,8 @@ def generate_time_series_plot(dataset_name, df, output_dir, dataset_type): def generate_comparison_report( # noqa: C901 primary_sheet, reference_sheet, - model, - primary_device, - reference_device, - primary_metrics, - reference_metrics, + primary_info, + reference_info, output_dir, ): """Generate a comparison report between primary and reference datasets""" @@ -671,6 +688,12 @@ def generate_comparison_report( # noqa: C901 # Create a string buffer to hold the report content report_content = [] + model = (primary_info["model"],) + primary_device = (primary_info["full_device"],) + reference_device = reference_info["full_device"] + primary_metrics = primary_info["metrics"] + reference_metrics = reference_info["metrics"] + # Header report_content.append("Private vs Public Stability Comparison") report_content.append("=" * 80) @@ -971,8 +994,10 @@ def generate_comparison_report( # noqa: C901 ) # Note about OS version difference if applicable - _, primary_device_base, primary_os = extract_model_device_os(primary_sheet) - _, reference_device_base, reference_os = extract_model_device_os(reference_sheet) + primary_device_base = primary_info.get("base_device", "") + primary_os = primary_info.get("os_version", "") + reference_device_base = reference_info.get("base_device", "") + reference_os = reference_info.get("os_version", "") if primary_os != reference_os and primary_os and reference_os: report_content.append("") @@ -1030,7 +1055,7 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 { "Sheet": sheet_name, "Model": info["model"], - "Device": info["device"], + "Device": info["full_device"], "Mean Latency (ms)": info["metrics"]["mean_raw_latency"], "CV (%)": info["metrics"]["cv_raw_latency"], "Stability Score": info["metrics"]["stability_score"], @@ -1103,8 +1128,8 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 # Device-based comparison # First, extract base device names for grouping device_base_map = {} - for sheet_name in primary_datasets: - _, device_base, _ = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + device_base = info.get("base_device", "") device_base_map[sheet_name] = device_base # Add base device to DataFrame @@ -1138,8 +1163,8 @@ def generate_intra_primary_summary(primary_datasets, output_dir): # noqa: C901 # OS version comparison if multiple OS versions exist os_versions = {} - for sheet_name in primary_datasets: - _, _, os_version = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + os_version = info.get("os_version", "") if os_version: # Only include if OS version was extracted os_versions[sheet_name] = os_version @@ -1254,9 +1279,13 @@ def generate_summary_report( # noqa: C901 # Primary datasets summary primary_data = [] for sheet_name, info in primary_datasets.items(): - model, device_base, os_version = extract_model_device_os(sheet_name) + model, device_base, os_version = ( + info.get("model", ""), + info.get("base_device", ""), + info.get("os_version", ""), + ) device_display = ( - f"{device_base} ({os_version})" if os_version else info["device"] + f"{device_base}({os_version})" if os_version else info["device"] ) primary_data.append( @@ -1287,9 +1316,13 @@ def generate_summary_report( # noqa: C901 if reference_datasets: reference_data = [] for sheet_name, info in reference_datasets.items(): - model, device_base, os_version = extract_model_device_os(sheet_name) + model, device_base, os_version = ( + info.get("model", ""), + info.get("base_device", ""), + info.get("os_version", ""), + ) device_display = ( - f"{device_base} ({os_version})" if os_version else info["device"] + f"{device_base}({os_version})" if os_version else info["device"] ) reference_data.append( @@ -1322,29 +1355,31 @@ def generate_summary_report( # noqa: C901 # Comparison summary for matching datasets comparison_data = [] - for primary_sheet, primary_info in primary_datasets.items(): - for ref_sheet, ref_info in reference_datasets.items(): - if is_matching_dataset(primary_sheet, ref_sheet): + for _, primary_info in primary_datasets.items(): + for _, ref_info in reference_datasets.items(): + if is_matching_dataset(primary_info, ref_info): primary_metrics = primary_info["metrics"] reference_metrics = ref_info["metrics"] # Extract model and device info for display - model, primary_device_base, primary_os = extract_model_device_os( - primary_sheet - ) - _, reference_device_base, reference_os = extract_model_device_os( - ref_sheet + model, primary_device_base, primary_os = ( + primary_info.get("model", ""), + primary_info.get("base_device", ""), + primary_info.get("os_version", ""), ) + reference_device_base, reference_os = ref_info.get( + "base_device", "" + ), ref_info.get("os_version", "") primary_device_display = ( f"{primary_device_base} ({primary_os})" if primary_os - else primary_info["device"] + else primary_info["full_device"] ) reference_device_display = ( f"{reference_device_base} ({reference_os})" if reference_os - else ref_info["device"] + else ref_info["full_device"] ) comparison_data.append( @@ -1424,15 +1459,15 @@ def generate_summary_report( # noqa: C901 # OS version insights if available os_versions = {} - for sheet_name in primary_datasets: - _, _, os_version = extract_model_device_os(sheet_name) + for sheet_name, info in primary_datasets.items(): + os_version = info.get("os_version", "") if os_version: os_versions[sheet_name] = os_version if os_versions and len(set(os_versions.values())) > 1: # Add OS version to primary DataFrame primary_df["OS Version"] = primary_df["Dataset"].map( - lambda x: extract_model_device_os(x)[2] + lambda x: primary_datasets[x].get("os_version", np.nan) ) # Remove rows with no OS version @@ -1498,11 +1533,11 @@ def main(): description="Analyze ML model latency stability from benchmark data." ) parser.add_argument( - "primary_file", + "--primary-file", help="Path to Excel file containing primary (private) benchmark data", ) parser.add_argument( - "--reference_file", + "--reference-file", help="Path to Excel file containing reference (public) benchmark data for comparison", default=None, ) diff --git a/.ci/scripts/benchmark_tooling/common.py b/.ci/scripts/benchmark_tooling/common.py new file mode 100644 index 00000000000..521e9f3b3ce --- /dev/null +++ b/.ci/scripts/benchmark_tooling/common.py @@ -0,0 +1,50 @@ +import json +import os +from typing import Any, Dict, List + +import pandas as pd + + +def read_excel_with_json_header(path: str) -> List[Dict[str, Any]]: + # Read all sheets into a dict of DataFrames, without altering + all_sheets = pd.read_excel(path, sheet_name=None, header=None, engine="openpyxl") + + results = [] + for sheet, df in all_sheets.items(): + # Extract JSON string from A1 (row 0, col 0) + json_str = df.iat[0, 0] + meta = json.loads(json_str) if isinstance(json_str, str) else {} + + # The actual data starts from the next row; treat row 1 as header + df_data = pd.read_excel(path, sheet_name=sheet, skiprows=1, engine="openpyxl") + results.append({"groupInfo": meta, "df": df_data, "sheetName": sheet}) + print(f"successfully fetched {len(results)} sheets from {path}") + return results + + +def read_all_csv_with_metadata(folder_path: str) -> List[Dict[str, Any]]: + results = [] # {filename: {"meta": dict, "df": DataFrame}} + for fname in os.listdir(folder_path): + if not fname.lower().endswith(".csv"): + continue + path = os.path.join(folder_path, fname) + with open(path, "r", encoding="utf-8") as f: + first_line = f.readline().strip() + try: + meta = json.loads(first_line) + except json.JSONDecodeError: + meta = {} + df = pd.read_csv(path, skiprows=1) + results.append({"groupInfo": meta, "df": df, "sheetName": fname}) + print(f"successfully fetched {len(results)} sheets from {folder_path}") + return results + + +import logging + +logging.basicConfig(level=logging.INFO) + +# For Excel files (assuming the Excel file is in the current directory) +file_path = "./private.xlsx" +res = read_excel_with_json_header(file_path) +logging.info(res) diff --git a/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py b/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py new file mode 100644 index 00000000000..d2d0b15d063 --- /dev/null +++ b/.ci/scripts/benchmark_tooling/get_benchmark_analysis_data.py @@ -0,0 +1,763 @@ +""" +ExecutorchBenchmark Analysis Data Retrieval + +This module provides tools for fetching, processing, and analyzing benchmark data +from the HUD Open API for ExecutorchBenchmark. It supports filtering data by (private) device pool names, +backends, and models, exporting results in various formats (JSON, DataFrame, Excel, CSV), +and customizing data retrieval parameters. +""" + +import argparse +import json +import logging +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Union + +import pandas as pd +import requests +from yaspin import yaspin + +logging.basicConfig(level=logging.INFO) + +# add here just for the records +VALID_PRIVATE_DEVICE_POOLS_MAPPINGS = { + "apple_iphone_15_private": [ + ("Apple iPhone 15 Pro (private)", "iOS 18.4.1"), + ("Apple iPhone 15 (private)", "iOS 18.0"), + ("Apple iPhone 15 Plus (private)", "iOS 17.4.1"), + ], + "samsung_s22_private": [ + ("Samsung Galaxy S22 Ultra 5G (private)", "Android 14"), + ("Samsung Galaxy S22 5G (private)", "Android 13"), + ], +} + +VALID_PRIVATE_DEVICE_POOLS_NAMES = list(VALID_PRIVATE_DEVICE_POOLS_MAPPINGS.keys()) + + +class OutputType(Enum): + """ + Enumeration of supported output formats for benchmark data. + + Values: + EXCEL: Export data to Excel spreadsheets + PRINT: Print data to console (default) + CSV: Export data to CSV files + JSON: Export data to JSON files + DF: Return data as pandas DataFrames + """ + + EXCEL = "excel" + PRINT = "print" + CSV = "csv" + JSON = "json" + DF = "df" + + +@dataclass +class BenchmarkQueryGroupDataParams: + """ + Parameters for querying benchmark data from HUD API. + + Attributes: + repo: Repository name (e.g., "pytorch/executorch") + benchmark_name: Name of the benchmark (e.g., "ExecuTorch") + start_time: ISO8601 formatted start time + end_time: ISO8601 formatted end time + group_table_by_fields: Fields to group tables by + group_row_by_fields: Fields to group rows by + """ + + repo: str + benchmark_name: str + start_time: str + end_time: str + group_table_by_fields: list + group_row_by_fields: list + + +@dataclass +class MatchingGroupResult: + """ + Container for benchmark results grouped by category. + + Attributes: + category: Category name (e.g., 'private', 'public') + data: List of benchmark data for this category + """ + + category: str + data: list + + +@dataclass +class BenchmarkFilters: + models: list + backends: list + devicePoolNames: list + + +BASE_URLS = { + "local": "http://localhost:3000", + "prod": "https://hud.pytorch.org", +} + + +def validate_iso8601_no_ms(value: str): + """ + Validate that a string is in ISO8601 format without milliseconds. + Args: + value: String to validate (format: YYYY-MM-DDTHH:MM:SS) + """ + try: + return datetime.strptime(value, "%Y-%m-%dT%H:%M:%S").strftime( + "%Y-%m-%dT%H:%M:%S" + ) + except ValueError: + raise argparse.ArgumentTypeError( + f"Invalid datetime format for '{value}'. Expected: YYYY-MM-DDTHH:MM:SS" + ) + + +class ExecutorchBenchmarkFetcher: + """ + Fetch and process benchmark data from HUD API for ExecutorchBenchmark. + + This class provides methods to: + 1. Fetch all benchmark data for a specified time range + 2. Get all private device info within the time range + 3. Filter the private device data if filter is provided + 4. Then use the filtered private device data to find matched the public device data using [model, backend, device, arch] + 3. Export results in various formats (JSON, DataFrame, Excel, CSV) + + Usage: + fetcher = ExecutorchBenchmarkFetcher() + fetcher.run(start_time, end_time) + fetcher.output_data(OutputType.EXCEL, output_dir="./results") + """ + + def __init__( + self, + env: str = "prod", + disable_logging: bool = False, + group_table_fields=None, + group_row_fields=None, + ): + """ + Initialize the ExecutorchBenchmarkFetcher. + + Args: + env: Environment to use ('local' or 'prod') + disable_logging: Whether to suppress log output + group_table_fields: Custom fields to group tables by (defaults to device, backend, arch, model) + group_row_fields: Custom fields to group rows by (defaults to workflow_id, job_id, granularity_bucket) + """ + self.env = env + self.base_url = self._get_base_url() + self.query_group_table_by_fields = ( + group_table_fields + if group_table_fields + else ["model", "backend", "device", "arch"] + ) + self.query_group_row_by_fields = ( + group_row_fields + if group_row_fields + else ["workflow_id", "job_id", "metadata_info.timestamp"] + ) + self.data = None + self.disable_logging = disable_logging + self.matching_groups: Dict[str, MatchingGroupResult] = {} + + def run( + self, + start_time: str, + end_time: str, + filters: Optional[BenchmarkFilters] = None, + ) -> None: + # reset group & raw data for new run + self.matching_groups = {} + self.data = None + + data = self._fetch_execu_torch_data(start_time, end_time) + if data is None: + logging.warning("no data fetched from the HUD API") + return None + self._proces_raw_data(data) + self._process_private_public_data(filters) + + def _filter_out_failure_only( + self, data_list: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + Clean data by removing rows that only contain FAILURE_REPORT metrics. + + Args: + data_list: List of benchmark data dictionaries + + Returns: + Filtered list with rows containing only FAILURE_REPORT removed + """ + ONLY = {"workflow_id", "metadata_info.timestamp", "job_id", "FAILURE_REPORT"} + for item in data_list: + filtered_rows = [ + row + for row in item.get("rows", []) + # Keep row only if it has additional fields beyond ONLY + if not set(row.keys()).issubset(ONLY) + ] + item["rows"] = filtered_rows + return [item for item in data_list if item.get("rows")] + + def _filter_public_result(self, private_list, all_public): + # find intersection betwen private and public tables. + common = list( + set([item["table_name"] for item in private_list]) + & set([item["table_name"] for item in all_public]) + ) + + if not self.disable_logging: + logging.info( + f"Found {len(common)} table names existed in both private and public, use it to filter public tables:" + ) + logging.info(json.dumps(common, indent=1)) + filtered_public = [item for item in all_public if item["table_name"] in common] + return filtered_public + + def get_result(self) -> Dict[str, List[Dict[str, Any]]]: + """ + Get a deep copy of the benchmark results. + + Returns: + Dictionary containing benchmark results grouped by category + """ + return deepcopy(self.to_dict()) + + def to_excel(self, output_dir: str = ".") -> None: + """ + Export benchmark results to Excel files. + Creates two Excel files: + - res_private.xlsx: Results for private devices + - res_public.xlsx: Results for public devices + Each file contains multiple sheets, one per benchmark configuration for private and public. + Args: + output_dir: Directory to save Excel files + """ + for item in self.matching_groups.values(): + self._write_multi_sheet_excel(item.data, output_dir, item.category) + + def _write_multi_sheet_excel(self, data_list, output_dir, file_name): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + file = os.path.join(output_dir, f"{file_name}.xlsx") + with pd.ExcelWriter(file, engine="xlsxwriter") as writer: + workbook = writer.book + for idx, entry in enumerate(data_list): + sheet_name = f"table{idx+1}" + df = pd.DataFrame(entry.get("rows", [])) + + # Encode metadata as compact JSON string + meta = entry.get("groupInfo", {}) + json_str = json.dumps(meta, separators=(",", ":")) + + worksheet = workbook.add_worksheet(sheet_name) + writer.sheets[sheet_name] = worksheet + + # Write JSON into A1 + worksheet.write_string(0, 0, json_str) + + logging.info( + f"Wrting excel sheet to file {file} with sheet name {sheet_name} for {entry['table_name']}" + ) + # Write DataFrame starting at row 2 (index 1) + df.to_excel(writer, sheet_name=sheet_name, startrow=1, index=False) + + def output_data( + self, output_type: OutputType = OutputType.PRINT, output_dir: str = "." + ) -> Any: + """ + Generate output in the specified format. + + Supports multiple output formats: + - PRINT: Print results to console + - JSON: Export to JSON files + - DF: Return as pandas DataFrames + - EXCEL: Export to Excel files + - CSV: Export to CSV files + + Args: + output_type: Format to output the data in + output_dir: Directory to save output files (for file-based formats) + + Returns: + Benchmark results in the specified format + """ + logging.info( + f"Generating output with type {output_type}: {[self.matching_groups.keys()]}" + ) + + o_type = self._to_output_type(output_type) + if o_type == OutputType.PRINT: + logging.info("\n ========= Generate print output ========= \n") + logging.info(json.dumps(self.get_result(), indent=2)) + elif o_type == OutputType.JSON: + logging.info("\n ========= Generate json output ========= \n") + file_path = self.to_json(output_dir) + logging.info(f"success, please check {file_path}") + elif o_type == OutputType.DF: + logging.info("\n ========= Generate dataframe output ========= \n") + res = self.to_df() + logging.info(res) + return res + elif o_type == OutputType.EXCEL: + logging.info("\n ========= Generate excel output ========= \n") + self.to_excel(output_dir) + elif o_type == OutputType.CSV: + logging.info("\n ========= Generate csv output ========= \n") + self.to_csv(output_dir) + return self.get_result() + + def _to_output_type(self, output_type: Any) -> OutputType: + if isinstance(output_type, str): + try: + return OutputType(output_type.lower()) + except ValueError: + logging.warning( + f"Invalid output type string: {output_type}. Defaulting to PRINT" + ) + return OutputType.JSON + elif isinstance(output_type, OutputType): + return output_type + logging.warning(f"Invalid output type: {output_type}. Defaulting to JSON") + return OutputType.JSON + + def to_json(self, output_dir: str = ".") -> Any: + """ + Export benchmark results to a JSON file. + + Args: + output_dir: Directory to save the JSON file + + Returns: + Path to the generated JSON file + """ + data = self.get_result() + return self.generate_json_file(data, "benchmark_results", output_dir) + + def generate_json_file(self, data, file_name, output_dir: str = "."): + """ + Generate a JSON file from the provided data. + + Args: + data: Data to write to the JSON file + file_name: Name for the JSON file (without extension) + output_dir: Directory to save the JSON file + + Returns: + Path to the generated JSON file + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + path = os.path.join(output_dir, file_name + ".json") + with open(path, "w") as f: + json.dump(data, f, indent=2) + return path + + def to_dict(self) -> Dict[str, List[Dict[str, Any]]]: + """ + Convert benchmark results to a dictionary. + + Returns: + Dictionary with categories as keys and benchmark data as values + """ + result = {} + for item in self.matching_groups.values(): + result[item.category] = item.data + return result + + def to_df(self) -> Dict[str, List[Dict[str, Union[Dict[str, Any], pd.DataFrame]]]]: + """ + Convert benchmark results to pandas DataFrames. + + Creates a dictionary with categories as keys and lists of DataFrames as values. + Each DataFrame represents one benchmark configuration. + + Returns: + Dictionary mapping categories ['private','public'] to lists of DataFrames "df" with metadata 'groupInfo'. + + """ + result = {} + for item in self.matching_groups.values(): + result[item.category] = [ + { + "groupInfo": item.get("groupInfo", {}), + "df": pd.DataFrame(item.get("rows", [])), + } + for item in item.data + ] + return result + + def to_csv(self, output_dir: str = ".") -> None: + """ + Export benchmark results to CSV files. + + Creates two CSV folders and one json file: + - private/: Results for private devices + - public/: Results for public devices + - benchmark_name_mappings.json: json dict which maps the generated csv file_name to + + Each file contains multiple CSV files, one per benchmark configuration for private and public. + + Args: + output_dir: Directory to save CSV files + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + logging.info(f"Created output directory: {output_dir}") + else: + logging.info(f"Using existing output directory: {output_dir}") + + for item in self.matching_groups.values(): + path = os.path.join(output_dir, item.category) + self._write_multiple_csv_files(item.data, path) + + def _write_multiple_csv_files( + self, data_list: List[Dict[str, Any]], output_dir: str, prefix: str = "" + ) -> None: + """ + Write multiple benchmark results to CSV files. + + Creates a CSV file for each benchmark configuration, with metadata + as a JSON string in the first row and data in subsequent rows. + + Args: + data_list: List of benchmark result dictionaries + output_dir: Directory to save CSV files + prefix: Optional prefix for CSV filenames + """ + os.makedirs(output_dir, exist_ok=True) + for idx, entry in enumerate(data_list): + filename = f"{prefix}_table{idx+1}.csv" if prefix else f"table{idx+1}.csv" + file_path = os.path.join(output_dir, filename) + + # Prepare DataFrame + df = pd.DataFrame(entry.get("rows", [])) + + # Prepare metadata JSON (e.g. groupInfo) + meta = entry.get("groupInfo", {}) + json_str = json.dumps(meta, separators=(",", ":")) + + logging.info(f"Wrting csv file to {file_path}") + + # Write metadata and data + with open(file_path, "w", encoding="utf-8", newline="") as f: + f.write(json_str + "\n") # First row: JSON metadata + df.to_csv(f, index=False) # Remaining rows: DataFrame rows + + def _get_base_url(self) -> str: + """ + Get the base URL for API requests based on environment. + + Returns: + Base URL string for the configured environment + """ + return BASE_URLS[self.env] + + def get_all_private_devices(self) -> Tuple[List[Any], List[Any]]: + """ + Print all devices found in the data. + Separates results by category and displays counts. + This is useful for debugging and understanding what data is available. + """ + if not self.data: + logging.info("No data found, please call get_data() first") + return ([], []) + + all_private = { + (group.get("device", ""), group.get("arch", "")) + for item in self.data + if (group := item.get("groupInfo", {})).get("aws_type") == "private" + } + iphone_set = {pair for pair in all_private if "iphone" in pair[0].lower()} + samsung_set = {pair for pair in all_private if "samsung" in pair[0].lower()} + + # logging + logging.info( + f"Found private {len(iphone_set)} iphone devices: {list(iphone_set)}" + ) + logging.info( + f"Found private {len(samsung_set)} samsung devices: {list(samsung_set)}" + ) + return (list(iphone_set), list(samsung_set)) + + def _generate_table_name( + self, group_info: Dict[str, Any], fields: List[str] + ) -> str: + """ + Generate a table name from group info fields. + + Creates a normalized string by joining specified fields from group info. + + Args: + group_info: Dictionary containing group information + fields: List of field names to include in the table name + + Returns: + Normalized table name string + """ + name = "-".join( + self.normalize_string(group_info[k]) + for k in fields + if k in group_info and group_info[k] + ) + + return name + + def _proces_raw_data(self, input_data: List[Dict[str, Any]]): + """ + Process raw benchmark data. + """ + logging.info(f"fetched {len(input_data)} data from HUD") + data = self._clean_data(input_data) + + for item in data: + org_group = item.get("groupInfo", {}) + if org_group.get("device", "").find("private") != -1: + item["groupInfo"]["aws_type"] = "private" + else: + item["groupInfo"]["aws_type"] = "public" + # Add full name joined by the group key fields + item["table_name"] = self._generate_table_name( + org_group, self.query_group_table_by_fields + ) + self.data = deepcopy(data) + + def _process_private_public_data(self, filters: Optional[BenchmarkFilters]): + """ + Process raw benchmark data. + """ + if not self.data: + logging.info("No data found, please call get_data() first") + return + + # + private_list = sorted( + ( + item + for item in self.data + if item.get("groupInfo", {}).get("aws_type") == "private" + ), + key=lambda x: x["table_name"], + ) + + if filters: + logging.info(f"Found {len(private_list)} private tables before filtering") + private_list = self.filter_private_results(private_list, filters) + else: + logging.info("filters is None, using all private results") + + all_public = sorted( + ( + item + for item in self.data + if item.get("groupInfo", {}).get("aws_type") == "public" + ), + key=lambda x: x["table_name"], + ) + public_list = self._filter_public_result(private_list, all_public) + + logging.info( + f"Found {len(private_list)} private tables, {[item['table_name'] for item in private_list]}" + ) + logging.info( + f"Found assoicated {len(public_list)} public tables, {json.dumps([item['table_name'] for item in public_list],indent=2)}" + ) + + self.matching_groups["private"] = MatchingGroupResult( + category="private", data=private_list + ) + self.matching_groups["public"] = MatchingGroupResult( + category="public", data=public_list + ) + + def _clean_data(self, data_list): + # filter data with arch equal exactly "",ios and android, this normally + # indicates it's job-level falure indicator + removed_gen_arch = [ + item + for item in data_list + if (arch := item.get("groupInfo", {}).get("arch")) is not None + and arch.lower() not in ("ios", "android") + ] + data = self._filter_out_failure_only(removed_gen_arch) + return data + + def _fetch_execu_torch_data( + self, start_time: str, end_time: str + ) -> Optional[List[Dict[str, Any]]]: + url = f"{self.base_url}/api/benchmark/group_data" + params_object = BenchmarkQueryGroupDataParams( + repo="pytorch/executorch", + benchmark_name="ExecuTorch", + start_time=start_time, + end_time=end_time, + group_table_by_fields=self.query_group_table_by_fields, + group_row_by_fields=self.query_group_row_by_fields, + ) + params = {k: v for k, v in params_object.__dict__.items() if v is not None} + with yaspin(text="Waiting for response", color="cyan") as spinner: + response = requests.get(url, params=params) + if response.status_code == 200: + spinner.ok("V") + return response.json() + else: + logging.info(f"Failed to fetch benchmark data ({response.status_code})") + logging.info(response.text) + spinner.fail("x") + return None + + def normalize_string(self, s: str) -> str: + s = s.lower().strip() + s = s.replace("+", "plus") + s = s.replace("-", "_") + s = s.replace(" ", "_") + s = re.sub(r"[^\w\-\.\(\)]", "_", s) + s = re.sub(r"_{2,}", "_", s) + s = s.replace("_(", "(").replace("(_", "(") + s = s.replace(")_", ")").replace("_)", ")") + s = s.replace("(private)", "") + return s + + def filter_private_results( + self, all_privates: List[Dict[str, Any]], filters: BenchmarkFilters + ): + """ + dynamically filter private device data based on filters, if any. + fetch all private devices within the time range, and then filter based on filter parameters + such as device_pool, backends, and models. + """ + private_devices = self.get_all_private_devices() + + device_pool = filters.devicePoolNames or set() + backends = filters.backends or set() + models = filters.models or set() + + if not backends and not device_pool and not models: + logging.info("No filters provided, using all private results") + return all_privates + + device_ios_match = set() + # hardcoded since we only have 2 device pools, each for iphone and samsung + if "apple_iphone_15_private" in device_pool: + device_ios_match.update( + private_devices[0] + ) # assumed to be list of (device, arch) + if "samsung_s22_private" in device_pool: + device_ios_match.update(private_devices[1]) + logging.info( + f"Applying filter: backends={backends}, devices={device_pool}, models={models}, pair_filter={bool(device_ios_match)}" + ) + results = [] + for item in all_privates: + info = item.get("groupInfo", {}) + if backends and info.get("backend") not in backends: + continue + + if device_ios_match: + # must match both device and arch in a record, otherwise skip + pair = (info.get("device", ""), info.get("arch", "")) + if pair not in device_ios_match: + continue + if models and info.get("model", "") not in models: + continue + results.append(item) + + logging.info( + f"Filtered from private data {len(all_privates)} → {len(results)} results" + ) + if not results: + logging.info("No results matched the filters. Something is wrong.") + return results + + +def argparsers(): + parser = argparse.ArgumentParser(description="Benchmark Analysis Runner") + + # Required common args + parser.add_argument( + "--startTime", + type=validate_iso8601_no_ms, + required=True, + help="Start time, ISO format (e.g. 2025-06-01T00:00:00)", + ) + parser.add_argument( + "--endTime", + type=validate_iso8601_no_ms, + required=True, + help="End time, ISO format (e.g. 2025-06-06T00:00:00)", + ) + parser.add_argument( + "--env", choices=["local", "prod"], default="prod", help="Environment" + ) + + parser.add_argument( + "--no-silent", + action="store_false", + dest="silent", + default=True, + help="Allow output (disable silent mode)", + ) + + # Options for generate_data + parser.add_argument( + "--outputType", + choices=["json", "df", "csv", "print", "excel"], + default="print", + help="Output format (only for generate_data)", + ) + + parser.add_argument( + "--outputDir", default=".", help="Output directory, default is ." + ) + parser.add_argument( + "--backends", + nargs="+", + help="Filter results by one or more backend full name(e.g. --backends qlora mv3) (OR logic within backends scope, AND logic with other filter type)", + ) + parser.add_argument( + "--device-pools", + nargs="+", # allow one or more values + choices=VALID_PRIVATE_DEVICE_POOLS_NAMES, + help="List of devices to include [apple_iphone_15_private, samsung_s22_private, you can include both] (OR logic within private-device-pools scope, AND logic with other filter type)", + ) + parser.add_argument( + "--models", + nargs="+", + help="Filter by one or more models (e.g. --backend 'meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8' 'mv3') (OR logic withn models scope, AND logic with other filter type)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = argparsers() + fetcher = ExecutorchBenchmarkFetcher(args.env, args.silent) + result = fetcher.run( + args.startTime, + args.endTime, + filters=BenchmarkFilters( + models=args.models, + backends=args.backends, + devicePoolNames=args.device_pools, + ), + ) + fetcher.output_data(args.outputType, args.outputDir) diff --git a/.ci/scripts/benchmark_tooling/requirements.txt b/.ci/scripts/benchmark_tooling/requirements.txt new file mode 100644 index 00000000000..3a2d69c0676 --- /dev/null +++ b/.ci/scripts/benchmark_tooling/requirements.txt @@ -0,0 +1,7 @@ +requests>=2.32.3 +xlsxwriter>=3.2.3 +pandas>=2.3.0 +yaspin>=3.1.0 +tabulate +matplotlib +openpyxl diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index 38a354eddf0..4f8dc7a30e5 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -188,6 +188,14 @@ test_model_with_qnn() { EXPORT_SCRIPT=edsr # Additional deps for edsr pip install piq + elif [[ "${MODEL_NAME}" == "albert" ]]; then + EXPORT_SCRIPT=albert + elif [[ "${MODEL_NAME}" == "bert" ]]; then + EXPORT_SCRIPT=bert + elif [[ "${MODEL_NAME}" == "distilbert" ]]; then + EXPORT_SCRIPT=distilbert + elif [[ "${MODEL_NAME}" == "eurobert" ]]; then + EXPORT_SCRIPT=eurobert else echo "Unsupported model $MODEL_NAME" exit 1 @@ -197,7 +205,25 @@ test_model_with_qnn() { # TODO(guangyang): Make QNN chipset matches the target device QNN_CHIPSET=SM8450 - "${PYTHON_EXECUTABLE}" -m examples.qualcomm.scripts.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS + SCRIPT_FOLDER="" + case "${MODEL_NAME}" in + "dl3"|"mv3"|"mv2"|"ic4"|"ic3"|"vit"|"mb"|"w2l") + SCRIPT_FOLDER=scripts + ;; + "albert"|"bert"|"distilbert") + pip install evaluate + SCRIPT_FOLDER=oss_scripts + # Bert models running in 16bit will encounter op validation fail on some operations, + # which requires CHIPSET >= SM8550. + QNN_CHIPSET=SM8550 + ;; + *) + echo "Unsupported model $MODEL_NAME" + exit 1 + ;; + esac + + "${PYTHON_EXECUTABLE}" -m examples.qualcomm.${SCRIPT_FOLDER}.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit) } diff --git a/.ci/scripts/tests/test_get_benchmark_analysis_data.py b/.ci/scripts/tests/test_get_benchmark_analysis_data.py new file mode 100644 index 00000000000..673452ab481 --- /dev/null +++ b/.ci/scripts/tests/test_get_benchmark_analysis_data.py @@ -0,0 +1,903 @@ +import importlib.util +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, mock_open, patch + +import pandas as pd + + +class TestBenchmarkAnalysis(unittest.TestCase): + @classmethod + def setUpClass(cls): + script_path = os.path.join( + ".ci", "scripts", "benchmark_tooling", "get_benchmark_analysis_data.py" + ) + spec = importlib.util.spec_from_file_location( + "get_benchmark_analysis_data", script_path + ) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module # Register before execution + spec.loader.exec_module(module) + cls.module = module + + """Test the validate_iso8601_no_ms function.""" + + def test_valid_iso8601(self): + """Test with valid ISO8601 format.""" + valid_date = "2025-06-01T00:00:00" + result = self.module.validate_iso8601_no_ms(valid_date) + self.assertEqual(result, valid_date) + + def test_invalid_iso8601(self): + """Test with invalid ISO8601 format.""" + invalid_dates = [ + "2025-06-01", # Missing time + "2025-06-01 00:00:00", # Space instead of T + "2025-06-01T00:00:00.000", # With milliseconds + "not-a-date", # Not a date at all + ] + for invalid_date in invalid_dates: + with self.subTest(invalid_date=invalid_date): + with self.assertRaises(self.module.argparse.ArgumentTypeError): + self.module.validate_iso8601_no_ms(invalid_date) + + def test_output_type_values(self): + """Test that OutputType has the expected values.""" + self.assertEqual(self.module.OutputType.EXCEL.value, "excel") + self.assertEqual(self.module.OutputType.PRINT.value, "print") + self.assertEqual(self.module.OutputType.CSV.value, "csv") + self.assertEqual(self.module.OutputType.JSON.value, "json") + self.assertEqual(self.module.OutputType.DF.value, "df") + + def setUp(self): + """Set up test fixtures.""" + self.maxDiff = None + + self.fetcher = self.module.ExecutorchBenchmarkFetcher( + env="prod", disable_logging=True + ) + + # Sample data for testing + self.sample_data_1 = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 1, + "metadata_info.timestamp": "2025-06-15T15:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 2, + "job_id": 2, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "s22_5g", + "arch": "android_13", + }, + "rows": [ + { + "workflow_id": 3, + "job_id": 3, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + ] + + self.sample_data_2 = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 1, + "metadata_info.timestamp": "2025-06-15T15:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 2, + "job_id": 2, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 6, + "job_id": 6, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 1.0, + }, + { + "workflow_id": 8, + "job_id": 8, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 1.0, + }, + ], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "s22_5g", + "arch": "android_13", + }, + "rows": [ + { + "workflow_id": 3, + "job_id": 3, + "metadata_info.timestamp": "2025-06-15T17:00:00Z", + "metric_1": 2.0, + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": "2025-06-15T14:00:00Z", + "metric_1": 3.0, + }, + ], + }, + ] + + def test_init(self): + """Test initialization of ExecutorchBenchmarkFetcher.""" + self.assertEqual(self.fetcher.env, "prod") + self.assertEqual(self.fetcher.base_url, "https://hud.pytorch.org") + self.assertEqual( + self.fetcher.query_group_table_by_fields, + ["model", "backend", "device", "arch"], + ) + self.assertEqual( + self.fetcher.query_group_row_by_fields, + ["workflow_id", "job_id", "metadata_info.timestamp"], + ) + self.assertTrue(self.fetcher.disable_logging) + self.assertEqual(self.fetcher.matching_groups, {}) + + def test_get_base_url(self): + """Test _get_base_url method.""" + self.assertEqual(self.fetcher._get_base_url(), "https://hud.pytorch.org") + + # Test with local environment + local_fetcher = self.module.ExecutorchBenchmarkFetcher(env="local") + self.assertEqual(local_fetcher._get_base_url(), "http://localhost:3000") + + def test_normalize_string(self): + """Test normalize_string method.""" + test_cases = [ + ("Test String", "test_string"), + ("test_string", "test_string"), + ("test string", "test_string"), + ("test--string", "test_string"), + ("test (private)", "test"), + ("test@#$%^&*", "test_"), + ] + + for input_str, expected in test_cases: + with self.subTest(input_str=input_str): + result = self.fetcher.normalize_string(input_str) + self.assertEqual(result, expected) + + @patch("requests.get") + def test_fetch_execu_torch_data_success(self, mock_get): + """Test _fetch_execu_torch_data method with successful response.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = self.sample_data_1 + mock_get.return_value = mock_response + + result = self.fetcher._fetch_execu_torch_data( + "2025-06-01T00:00:00", "2025-06-02T00:00:00" + ) + + self.assertEqual(result, self.sample_data_1) + mock_get.assert_called_once() + + @patch("requests.get") + def test_fetch_execu_torch_data_failure(self, mock_get): + """Test _fetch_execu_torch_data method with failed response.""" + mock_response = MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_get.return_value = mock_response + + result = self.fetcher._fetch_execu_torch_data( + "2025-06-01T00:00:00", "2025-06-02T00:00:00" + ) + + self.assertIsNone(result) + mock_get.assert_called_once() + + def test_filter_out_failure_only(self): + """Test _filter_out_failure_only method.""" + test_data = [ + { + "rows": [ + { + "workflow_id": 1, + "job_id": 2, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 8, + "job_id": 9, + "metadata_info.timestamp": 10, + "metric": 11.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 10, + "job_id": 12, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 21, + "job_id": 15, + "metadata_info.timestamp": 6, + "FAILURE_REPORT": "0", + }, + ] + }, + ] + + expected = [ + { + "rows": [ + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ] + }, + { + "rows": [ + { + "workflow_id": 8, + "job_id": 9, + "metadata_info.timestamp": 10, + "metric": 11.0, + }, + ] + }, + ] + + result = self.fetcher._filter_out_failure_only(test_data) + self.assertEqual(result, expected) + + def test_filter_public_result(self): + """Test _filter_public_result method.""" + private_list = [ + {"table_name": "model1_backend1"}, + {"table_name": "model2_backend2"}, + ] + + public_list = [ + {"table_name": "model1_backend1"}, + {"table_name": "model3_backend3"}, + ] + + expected = [{"table_name": "model1_backend1"}] + + result = self.fetcher._filter_public_result(private_list, public_list) + self.assertEqual(result, expected) + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_filter_private_results(self, mock_fetch): + """Test filter_private_results method with various filter combinations.""" + # Create test data + test_data = [ + { + "groupInfo": { + "model": "mv3", + "backend": "coreml_fp16", + "device": "Apple iPhone 15 Pro (private)", + "arch": "iOS 18.0", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 1.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "test_backend", + "device": "Apple iPhone 15 Pro (private)", + "arch": "iOS 14.1.0", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 1.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "Samsung Galaxy S22 Ultra 5G (private)", + "arch": "Android 14", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "Samsung Galaxy S22 Ultra 5G (private)", + "arch": "Android 13", + "total_rows": 10, + "aws_type": "private", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + "backend": "llama3_spinquant", + "device": "Apple iPhone 15", + "arch": "iOS 18.0", + "total_rows": 19, + "aws_type": "public", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "coreml_fp16", + "device": "Apple iPhone 15 Pro Max", + "arch": "iOS 17.0", + "total_rows": 10, + "aws_type": "public", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", + "backend": "test", + "device": "Samsung Galaxy S22 Ultra 5G", + "arch": "Android 14", + "total_rows": 10, + "aws_type": "public", + }, + "rows": [{"metric_1": 2.0}], + }, + ] + + mock_fetch.return_value = test_data + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + # Test with no filters + empty_filters = self.module.BenchmarkFilters( + models=None, backends=None, devicePoolNames=None + ) + + result = self.fetcher.filter_private_results(test_data, empty_filters) + self.assertEqual(result, test_data) + + # Test with model filter + model_filters = self.module.BenchmarkFilters( + models=["meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"], + backends=None, + devicePoolNames=None, + ) + result = self.fetcher.filter_private_results(test_data, model_filters) + self.assertEqual(len(result), 2) + self.assertTrue( + all( + item["groupInfo"]["model"] + == "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8" + for item in result + ) + ) + + # Test with backend filter + backend_filters = self.module.BenchmarkFilters( + models=None, backends=["coreml_fp16", "test"], devicePoolNames=None + ) + result = self.fetcher.filter_private_results(test_data, backend_filters) + self.assertEqual(len(result), 3) + self.assertTrue( + all( + item["groupInfo"]["backend"] in ["coreml_fp16", "test"] + for item in result + ) + ) + + # Test with device filter + device_filters = self.module.BenchmarkFilters( + models=None, backends=None, devicePoolNames=["samsung_s22_private"] + ) + result = self.fetcher.filter_private_results(test_data, device_filters) + self.assertEqual(len(result), 2) + self.assertTrue( + all( + "Samsung Galaxy S22 Ultra 5G (private)" in item["groupInfo"]["device"] + for item in result + ) + ) + + # Test with combined filters (And logic fails) + combined_filters = self.module.BenchmarkFilters( + models=["meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"], + backends=["xnnpack_q8"], + devicePoolNames=None, + ) + result = self.fetcher.filter_private_results(test_data, combined_filters) + self.assertEqual(len(result), 0) + + # Test with combined filters (And logic success) + combined_filters = self.module.BenchmarkFilters( + models=["mv3"], + backends=None, + devicePoolNames=["apple_iphone_15_private"], + ) + result = self.fetcher.filter_private_results(test_data, combined_filters) + self.assertEqual(len(result), 2) + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_run_without_public_match(self, mock_fetch): + """Test run method.""" + # Setup mocks + mock_fetch.return_value = self.sample_data_1 + # Run the method + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00") + result = self.fetcher.get_result() + + # Verify results + self.assertEqual(result, {"private": [self.sample_data_1[0]], "public": []}) + self.assertEqual(len(self.fetcher.matching_groups), 2) + self.assertIn("private", self.fetcher.matching_groups) + self.assertIn("public", self.fetcher.matching_groups) + + # Verify mocks were called + mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_run_with_public_match(self, mock_fetch): + """Test run method.""" + # Setup mocks + mock_fetch.return_value = self.sample_data_2 + + # Run the method + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00") + result = self.fetcher.get_result() + + # Verify results + self.assertEqual( + result, + {"private": [self.sample_data_2[0]], "public": [self.sample_data_2[1]]}, + ) + self.assertEqual(len(self.fetcher.matching_groups), 2) + self.assertIn("private", self.fetcher.matching_groups) + self.assertIn("public", self.fetcher.matching_groups) + # Verify mocks were called + mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_run_with_failure_report(self, mock_fetch): + """Test run method.""" + # Setup mocks + mock_data = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 2, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ], + }, + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max", + "arch": "ios_17.4.3", + }, + "rows": [ + { + "workflow_id": 1, + "job_id": 2, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + { + "workflow_id": 1, + "job_id": 2, + "metadata_info.timestamp": 3, + "FAILURE_REPORT": "0", + }, + ], + }, + ] + + expected_private = { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17.4.3", + "aws_type": "private", + }, + "rows": [ + { + "workflow_id": 4, + "job_id": 5, + "metadata_info.timestamp": 6, + "metric": 7.0, + }, + ], + "table_name": "llama3-qlora-iphone_15_pro_max-ios_17.4.3", + } + mock_fetch.return_value = mock_data + # Run the method + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00") + result = self.fetcher.get_result() + # Verify results + self.assertEqual(result.get("private", []), [expected_private]) + self.assertEqual(len(self.fetcher.matching_groups), 2) + self.assertIn("private", self.fetcher.matching_groups) + self.assertIn("public", self.fetcher.matching_groups) + # Verify mocks were called + mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_run_no_data(self, mock_fetch): + """Test run method when no data is fetched.""" + mock_fetch.return_value = None + + result = self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + self.assertIsNone(result) + self.assertEqual(self.fetcher.matching_groups, {}) + mock_fetch.assert_called_once_with("2025-06-01T00:00:00", "2025-06-02T00:00:00") + + @patch( + "get_benchmark_analysis_data.ExecutorchBenchmarkFetcher._fetch_execu_torch_data" + ) + def test_run_with_filters(self, mock_fetch): + """Test run method with filters.""" + # Setup mock data + mock_data = [ + { + "groupInfo": { + "model": "llama3", + "backend": "qlora", + "device": "Iphone 15 pro max (private)", + "arch": "ios_17", + }, + "rows": [{"metric_1": 1.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "s22_5g (private)", + "arch": "android_13", + }, + "rows": [{"metric_1": 2.0}], + }, + { + "groupInfo": { + "model": "mv3", + "backend": "xnnpack_q8", + "device": "s22_5g", + "arch": "android_13", + }, + "rows": [{"metric_1": 3.0}], + }, + ] + mock_fetch.return_value = mock_data + + # Create filters for llama3 model only + filters = self.module.BenchmarkFilters( + models=["llama3"], backends=None, devicePoolNames=None + ) + # Run the method with filters + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters) + result = self.fetcher.get_result() + print("result1", result) + + # Verify results - should only have llama3 in private results + self.assertEqual(len(result["private"]), 1) + self.assertEqual(result["private"][0]["groupInfo"]["model"], "llama3") + + # Public results should be empty since there's no matching table_name + self.assertEqual(result["public"], []) + + # Test with backend filter + filters = self.module.BenchmarkFilters( + models=None, backends=["xnnpack_q8"], devicePoolNames=None + ) + self.fetcher.run("2025-06-01T00:00:00", "2025-06-02T00:00:00", filters) + result = self.fetcher.get_result() + + print("result", result) + + # Verify results - should only have xnnpack_q8 in private results + self.assertEqual(len(result["private"]), 1) + self.assertEqual(result["private"][0]["groupInfo"]["backend"], "xnnpack_q8") + + # Public results should have the matching xnnpack_q8 entry + self.assertEqual(len(result["public"]), 1) + self.assertEqual(result["public"][0]["groupInfo"]["backend"], "xnnpack_q8") + + def test_to_dict(self): + """Test to_dict method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", data=[{"key": "private_value"}] + ), + "public": self.module.MatchingGroupResult( + category="public", data=[{"key": "public_value"}] + ), + } + + expected = { + "private": [{"key": "private_value"}], + "public": [{"key": "public_value"}], + } + + result = self.fetcher.to_dict() + self.assertEqual(result, expected) + + def test_to_df(self): + """Test to_df method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", + data=[{"groupInfo": {"model": "llama3"}, "rows": [{"metric1": 1.0}]}], + ), + } + + result = self.fetcher.to_df() + + self.assertIn("private", result) + self.assertEqual(len(result["private"]), 1) + self.assertIn("groupInfo", result["private"][0]) + self.assertIn("df", result["private"][0]) + self.assertIsInstance(result["private"][0]["df"], pd.DataFrame) + self.assertEqual(result["private"][0]["groupInfo"], {"model": "llama3"}) + + @patch("os.makedirs") + @patch("json.dump") + @patch("builtins.open", new_callable=mock_open) + def test_to_json(self, mock_file, mock_json_dump, mock_makedirs): + """Test to_json method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", data=[{"key": "value"}] + ), + } + + with tempfile.TemporaryDirectory() as temp_dir: + result = self.fetcher.to_json(temp_dir) + + # Check that the file path is returned + self.assertEqual(result, os.path.join(temp_dir, "benchmark_results.json")) + + # Check that the file was opened for writing + mock_file.assert_called_once_with( + os.path.join(temp_dir, "benchmark_results.json"), "w" + ) + + # Check that json.dump was called with the expected data + mock_json_dump.assert_called_once() + args, _ = mock_json_dump.call_args + self.assertEqual(args[0], {"private": [{"key": "value"}]}) + + @patch("pandas.DataFrame.to_excel") + @patch("pandas.ExcelWriter") + @patch("os.makedirs") + def test_to_excel(self, mock_makedirs, mock_excel_writer, mock_to_excel): + """Test to_excel method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", + data=[ + { + "groupInfo": {"model": "llama3"}, + "rows": [{"metric1": 1.0}], + "table_name": "llama3_table", + } + ], + ), + } + + # Mock the context manager for ExcelWriter + mock_writer = MagicMock() + mock_excel_writer.return_value.__enter__.return_value = mock_writer + mock_writer.book = MagicMock() + mock_writer.book.add_worksheet.return_value = MagicMock() + mock_writer.sheets = {} + + with tempfile.TemporaryDirectory() as temp_dir: + self.fetcher.to_excel(temp_dir) + + # Check that ExcelWriter was called with the expected path + mock_excel_writer.assert_called_once_with( + os.path.join(temp_dir, "private.xlsx"), engine="xlsxwriter" + ) + + # Check that to_excel was called + mock_to_excel.assert_called_once() + + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open) + @patch("pandas.DataFrame.to_csv") + def test_to_csv(self, mock_to_csv, mock_file, mock_makedirs): + """Test to_csv method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", + data=[{"groupInfo": {"model": "llama3"}, "rows": [{"metric1": 1.0}]}], + ), + } + + with tempfile.TemporaryDirectory() as temp_dir: + self.fetcher.to_csv(temp_dir) + + # Check that the directory was created + mock_makedirs.assert_called() + + # Check that the file was opened for writing + mock_file.assert_called_once() + + # Check that to_csv was called + mock_to_csv.assert_called_once() + + def test_to_output_type(self): + """Test _to_output_type method.""" + # Test with string values + self.assertEqual( + self.fetcher._to_output_type("excel"), self.module.OutputType.EXCEL + ) + self.assertEqual( + self.fetcher._to_output_type("print"), self.module.OutputType.PRINT + ) + self.assertEqual( + self.fetcher._to_output_type("csv"), self.module.OutputType.CSV + ) + self.assertEqual( + self.fetcher._to_output_type("json"), self.module.OutputType.JSON + ) + self.assertEqual(self.fetcher._to_output_type("df"), self.module.OutputType.DF) + + # Test with enum values + self.assertEqual( + self.fetcher._to_output_type(self.module.OutputType.EXCEL), + self.module.OutputType.EXCEL, + ) + + # Test with invalid values + self.assertEqual( + self.fetcher._to_output_type("invalid"), self.module.OutputType.JSON + ) + self.assertEqual(self.fetcher._to_output_type(123), self.module.OutputType.JSON) + + @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_json") + @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_df") + @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_excel") + @patch("get_benchmark_analysis_data.ExecutorchBenchmarkFetcher.to_csv") + def test_output_data(self, mock_to_csv, mock_to_excel, mock_to_df, mock_to_json): + """Test output_data method.""" + # Setup test data + self.fetcher.matching_groups = { + "private": self.module.MatchingGroupResult( + category="private", data=[{"key": "value"}] + ), + } + + # Test PRINT output + result = self.fetcher.output_data(self.module.OutputType.PRINT) + self.assertEqual(result, {"private": [{"key": "value"}]}) + + # Test JSON output + mock_to_json.return_value = "/path/to/file.json" + result = self.fetcher.output_data(self.module.OutputType.JSON) + self.assertEqual(result, {"private": [{"key": "value"}]}) + mock_to_json.assert_called_once_with(".") + + # Test DF output + mock_to_df.return_value = {"private": [{"df": "value"}]} + result = self.fetcher.output_data(self.module.OutputType.DF) + self.assertEqual(result, {"private": [{"df": "value"}]}) + mock_to_df.assert_called_once() + + # Test EXCEL output + result = self.fetcher.output_data(self.module.OutputType.EXCEL) + self.assertEqual(result, {"private": [{"key": "value"}]}) + mock_to_excel.assert_called_once_with(".") + + # Test CSV output + result = self.fetcher.output_data(self.module.OutputType.CSV) + self.assertEqual(result, {"private": [{"key": "value"}]}) + mock_to_csv.assert_called_once_with(".") + + +if __name__ == "__main__": + unittest.main() diff --git a/.ci/scripts/wheel/pre_build_script.sh b/.ci/scripts/wheel/pre_build_script.sh index 2bf8c7c73f0..424529af864 100755 --- a/.ci/scripts/wheel/pre_build_script.sh +++ b/.ci/scripts/wheel/pre_build_script.sh @@ -14,4 +14,4 @@ set -euxo pipefail # which does install them. Though we'd need to disable build isolation to be # able to see the installed torch package. -"${GITHUB_WORKSPACE}/${REPOSITORY}/install_requirements.sh" +"${GITHUB_WORKSPACE}/${REPOSITORY}/install_requirements.sh" --example diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index e08c536088e..1a6d63f1bd1 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -342,8 +342,8 @@ jobs: git clone https://github.com/huggingface/optimum-executorch pushd optimum-executorch # There is no release yet, for CI stability, always test from the same commit on main - git checkout 1c653dc49812fc431a22312c7295d97005d22e12 - python install_dev.py + git checkout 4c3b18f6cca68c5ccff809131d570062723d7188 + python install_dev.py --skip_override_torch pip list ARGS=( diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 683ac5170b4..0c03f55f82e 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -347,8 +347,8 @@ jobs: git clone https://github.com/huggingface/optimum-executorch pushd optimum-executorch # There is no release yet, for CI stability, always test from the same commit on main - git checkout 1c653dc49812fc431a22312c7295d97005d22e12 - ${CONDA_RUN} python install_dev.py + git checkout 4c3b18f6cca68c5ccff809131d570062723d7188 + ${CONDA_RUN} python install_dev.py --skip_override_torch pip list ARGS=( diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 1cfef0273be..a4996459f8a 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -480,6 +480,32 @@ jobs: PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn" + test-qnn-optimum-model: + name: test-qnn-optimum-model + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + strategy: + matrix: + dtype: [fp32] + model: [albert, bert, distilbert] # eurobert requires transfomer >= 4.48.0, skip for now + fail-fast: false + with: + runner: linux.2xlarge + docker-image: executorch-ubuntu-22.04-qnn-sdk + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 900 + script: | + # The generic Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake + PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh + PYTHON_EXECUTABLE=python bash .ci/scripts/test_model.sh ${{ matrix.model }} "cmake" "qnn" + test-apple-model: name: test-apple-model uses: pytorch/test-infra/.github/workflows/macos_job.yml@main @@ -571,9 +597,8 @@ jobs: git clone https://github.com/huggingface/optimum-executorch pushd optimum-executorch # There is no release yet, for CI stability, always test from the same commit on main - git checkout 1c653dc49812fc431a22312c7295d97005d22e12 - pip install .[tests] - pip install transformers==4.52.4 + git checkout 4c3b18f6cca68c5ccff809131d570062723d7188 + python install_dev.py --skip_override_torch popd pip list echo "::endgroup::" diff --git a/.gitignore b/.gitignore index c257883ee40..553729e9b68 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,9 @@ xcuserdata/ *.xcworkspace/ *.xcframework/ +# clangd +.cache/ + # misc /.vscode/ *.so diff --git a/CODEOWNERS b/CODEOWNERS index 0be30fbe552..10baed9ede4 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -48,7 +48,7 @@ /extension/flat_tensor @lucylq /extension/gguf_util @larryliu0820 /extension/kernel_util @kimishpatel @manuelcandales @swolchok -/extension/llm @jackzhxng @larryliu0820 @swolchok +/extension/llm @jackzhxng @larryliu0820 @swolchok @mergennachin /extension/memory_allocator @JacobSzwejbka @swolchok /extension/module @shoumikhin /extension/parallel @kimishpatel @swolchok diff --git a/README.md b/README.md index 22a6290e472..8003b25b17b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C Platform Support: - Operating Systems: - iOS - - Mac + - MacOS (ARM64) - Android - Linux - Microcontrollers diff --git a/backends/__init__.py b/backends/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/arm/CMakeLists.txt b/backends/arm/CMakeLists.txt index 39a51c56b14..b5e76e778a5 100644 --- a/backends/arm/CMakeLists.txt +++ b/backends/arm/CMakeLists.txt @@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..) endif() +add_compile_options("-Wall" "-Werror") + include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10) diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index e7a68b050c4..440938fd49c 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -5,10 +5,12 @@ from . import arm_pass_utils # noqa +from .arm_pass import ArmPass # noqa # usort: skip +from .add_bias_pass import AddBiasPass # noqa from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder # noqa from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass # noqa -from .arm_pass import ArmPass # noqa from .broadcast_args_pass import BroadcastArgsPass # noqa +from .cast_bool_to_int8_pass import CastBoolToInt8Pass # noqa from .cast_int64_pass import CastInt64BuffersToInt32Pass # noqa from .cast_to_int32_pass import CastToInt32Pass # noqa from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass # noqa @@ -20,17 +22,21 @@ from .convert_split_to_slice import ConvertSplitToSlicePass # noqa from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa from .convert_to_clamp import ConvertToClampPass # noqa +from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa +from .decompose_grouped_conv import DecomposeGroupedConv # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa +from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa +from .decompose_round_pass import DecomposeRoundPass # noqa from .decompose_select import DecomposeSelectPass # noqa from .decompose_silu_pass import DecomposeSiluPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py new file mode 100644 index 00000000000..31c0c0505cb --- /dev/null +++ b/backends/arm/_passes/add_bias_pass.py @@ -0,0 +1,62 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.transforms.utils import create_constant_placeholder + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.export.graph_signature import InputKind + + +class AddBiasPass(ArmPass): + """TOSA requires convolution nodes to have a bias input. + This pass adds a bias input to convolution nodes that do not have one. + The bias is set to zero. + """ + + targeted_ops = (exir_ops.edge.aten.convolution.default,) + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target not in self.targeted_ops: + continue + + if len(node.all_input_nodes) < 3: + modified = True + # bias is missing + weight_node = node.all_input_nodes[1] + output_channels = get_first_fake_tensor(weight_node).shape[0] + # add a node containging zeros + # if quantized, use int32, otherwise use float32 + if ( + "output_qparams" in node.meta + and len(node.meta["output_qparams"]) > 0 + ): + bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32) + else: + bias_data = torch.zeros( + size=(output_channels,), dtype=torch.float32 + ) + + with graph_module.graph.inserting_after(weight_node): + bias_node = create_constant_placeholder( + self.exported_program, + graph=graph_module.graph, + kind=InputKind.PARAMETER, + data=bias_data, + persistent_buffer=True, + name=f"{node.name}_bias", + ) + node.update_arg(2, bias_node) + + if modified: + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 4744845dc2a..f8ead856fbb 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -5,15 +5,12 @@ # pyre-unsafe -from typing import cast import torch from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, - insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -59,20 +56,10 @@ class AnnotateChannelsLastDimOrder(ExportPass): def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node): """ - returns True for dq and w in the following sequences; + returns True for w in the following sequence; w -> depthwise_conv2d -> ... - w -> dq -> depthwise_conv2d -> ... """ - if node.op == "call_function": - if node.target != dq_op: - return False - prev_node = node.args[0] - if cast(torch.fx.Node, prev_node).op != "placeholder": - return False - if is_consumer_node_depthwise_conv2d(node): - consumer_node = list(node.users)[0] - return consumer_node.args[1] == node - elif node.op == "placeholder": + if node.op == "placeholder": # node is an input, weight or bias node consumer_node = list(node.users)[0] if self.is_weight_node_for_depthwise_conv2d(consumer_node): @@ -129,8 +116,6 @@ def is_channel_reshape(input_shape, output_shape): @staticmethod def insert_input_transpose(node, input_node, graph_module): - quantize = input_node.target == dq_op - q_params = input_node.args[1:] if quantize else None with graph_module.graph.inserting_before(node): permute_node = create_node( graph_module.graph, @@ -143,8 +128,6 @@ def insert_input_transpose(node, input_node, graph_module): else AnnotateChannelsLastDimOrder.NHWC_inverse_order ), ), - quantize=quantize, - q_params=q_params, ) node.replace_input_with(input_node, permute_node) @@ -185,11 +168,6 @@ def insert_output_transpose(node, graph_module): for user in users: user.replace_input_with(node, permute_node) - quantize = node.args[0] == q_op - if quantize: - q_params = node.args[0].args[1:] - insert_q_dq_pair(graph_module.graph, node, q_params) - @staticmethod def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module @@ -225,10 +203,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ for node in graph_module.graph.nodes: - if node.op != "call_function": + # call_function and placeholder allowed due to + # index.Tensor being able to come in as both + if node.op not in ["call_function", "placeholder"]: continue - elif node.target == exir_ops.edge.aten.view_copy.default: + elif node.target in ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.index.Tensor, + ): + # For index.Tensor: + # If we want to support 4D indexing tensors this logic + # should be updated. input_node = node.args[0] input_shape = input_node.meta["val"].shape output_shape = node.meta["val"].shape diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 4043a9d7070..9f9168d9238 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -7,13 +7,14 @@ import itertools import operator -from typing import List +from typing import cast, List import torch from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs +from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -61,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult: } for partition in matmul_partitions: quantized_input = all( - input_node.target == dq_op for input_node in partition.input_nodes + input_node.target in dq_ops for input_node in partition.input_nodes ) matmul_node = [ node for node in partition.nodes if node.target in matmul_targets @@ -74,17 +75,14 @@ def call(self, graph_module: GraphModule) -> PassResult: input_node = self._match_partition_to_node( node, partition.input_nodes ) - input_node_qargs = QuantArgs.from_operator( - input_node.target, input_node.args - ) # Insert new dq-node just before the mm/bmm with input_node's qparams with graph_module.graph.inserting_before(matmul_node): # Create new dq-node before matmul dq_node = create_node( graph=graph_module.graph, - op_target=dq_op, + op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type] ) - dq_node.args = (node, *input_node_qargs) + dq_node.args = (node, *input_node.args[1:]) matmul_node.replace_input_with(node, dq_node) for partition_input in partition.input_nodes: @@ -95,19 +93,16 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.graph.erase_node(partition_input) partition_output = list(partition.output_nodes[0].users)[0] - quantized_output = partition_output.target == q_op + quantized_output = partition_output.target in q_ops if quantized_output: - output_node_qargs = QuantArgs.from_operator( - partition_output.target, partition_output.args - ) with graph_module.graph.inserting_after(matmul_node): # Create q-node after matmul q_node = create_node( graph=graph_module.graph, - op_target=q_op, + op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type] ) matmul_node.replace_all_uses_with(q_node) - q_node.args = (matmul_node, *output_node_qargs) + q_node.args = (matmul_node, *partition_output.args[1:]) # Remove partition output q-node partition_output.replace_all_uses_with( partition_output.all_input_nodes[0] diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 124b68863c9..07a4416cd74 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -7,9 +7,11 @@ # pyre-unsafe from executorch.backends.arm._passes import ( + AddBiasPass, AnnotateChannelsLastDimOrder, AnnotateDecomposedMatmulPass, BroadcastArgsPass, + CastBoolToInt8Pass, CastInt64BuffersToInt32Pass, CastToInt32Pass, ComputeConstantOpsAOT, @@ -23,17 +25,21 @@ ConvertSplitToSlicePass, ConvertSqueezesToViewPass, ConvertToClampPass, + DecomposeAvgPool2d, DecomposeCosineSimilarityPass, DecomposeDivPass, DecomposeEmbeddingPass, DecomposeGeluPass, + DecomposeGroupedConv, DecomposeGroupNormPass, DecomposeLayerNormPass, DecomposeLeakyReLUPass, DecomposeLinearPass, DecomposeLinearVectorNormPass, + DecomposeMaxPool2DPass, DecomposeMeanDimPass, DecomposeNotEqualPass, + DecomposeRoundPass, DecomposeSelectPass, DecomposeSiluPass, DecomposeSoftmaxPass, @@ -62,7 +68,6 @@ UnsqueezeBeforeRepeatPass, UnsqueezeScalarPlaceholdersPass, ) - from executorch.backends.arm.tosa_specification import ( TosaLoweringContext, TosaSpecification, @@ -92,7 +97,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(ConvertMmToBmmPass()) - self.add_pass(DecomposeLinearPass()) self.add_pass(DecomposeLinearVectorNormPass()) self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) @@ -105,17 +109,21 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul if self.tosa_spec.is_U55_subset: self.add_pass(CastToInt32Pass()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] + self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) if self.tosa_spec.is_U55_subset: self.add_pass(BroadcastArgsPass()) + self.add_pass(DecomposeLinearPass()) + self.add_pass(DecomposeAvgPool2d()) self.add_pass(ComputeConstantOpsAOT(exported_program)) + self.add_pass(DecomposeGroupedConv()) self.add_pass(RemoveClonePass()) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(ConvertExpandCopyToRepeatPass()) @@ -123,11 +131,13 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(DecomposeSumPass()) self.add_pass(Conv1dUnsqueezePass()) + self.add_pass(DecomposeMaxPool2DPass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) @@ -137,8 +147,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(ConvertIntPowToMuls()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) @@ -166,12 +178,14 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) - self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] + self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) self.add_pass(MatchArgRanksPass(exported_program)) + self.add_pass(DecomposeAvgPool2d()) self.add_pass(ComputeConstantOpsAOT(exported_program)) + self.add_pass(DecomposeGroupedConv()) self.add_pass(RemoveClonePass()) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(ConvertExpandCopyToRepeatPass()) @@ -179,11 +193,13 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(DecomposeSumPass()) self.add_pass(Conv1dUnsqueezePass()) + self.add_pass(DecomposeMaxPool2DPass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(AnnotateChannelsLastDimOrder()) @@ -216,6 +232,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(InsertCastForOpsWithInt64InputPass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(DecomposeScaledDotProductAttention()) + self.add_pass(DecomposeRoundPass()) + self.add_pass(CastBoolToInt8Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) @@ -229,6 +247,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeLinearVectorNormPass()) self.add_pass(DecomposeSqrtPass()) self.add_pass(DecomposeSiluPass()) + self.add_pass(DecomposeAvgPool2d()) if self.tosa_spec.is_U55_subset: # Numerically stable softmax uses amax which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/cast_bool_to_int8_pass.py b/backends/arm/_passes/cast_bool_to_int8_pass.py new file mode 100644 index 00000000000..1352671b01e --- /dev/null +++ b/backends/arm/_passes/cast_bool_to_int8_pass.py @@ -0,0 +1,58 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool as input +# If input/output is bool lest add a cast/conversion pass before/after to/from int8. + +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class CastBoolToInt8Pass(ExportPass): + """Casts the input to int8 if it is not already and casts back the output to the original input dtype.""" + + targeted_ops = { + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + } + + def call_operator(self, op, args, kwargs, meta): + if op not in self.targeted_ops: + return super().call_operator(op, args, kwargs, meta) + + new_args: list = [] + did_cast = False + for arg in args: + if arg.data.dtype == torch.bool: + new_args.append( + super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (arg,), + {"dtype": torch.int8}, + meta, + ) + ) + did_cast = True + else: + new_args.append(arg) + + output = super().call_operator( + op, + tuple(new_args), + {}, + meta, + ) + + if did_cast: + output = super().call_operator( + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + (output,), + {"dtype": args[0].data.dtype}, + meta, + ) + return output diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index 87512f9fb3c..0cdd0422b61 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -35,6 +35,8 @@ def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node): def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: + if len(node.users) == 0: + continue fake_tensor = node.meta["val"] if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): continue diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py new file mode 100644 index 00000000000..0eb3ce34ecd --- /dev/null +++ b/backends/arm/_passes/decompose_avg_pool2d.py @@ -0,0 +1,121 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from executorch.backends.arm.operators.operator_validation_utils import ( + adjust_pooling_pad_if_needed, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,) +aten_div_ops = (torch.ops.aten.avg_pool2d.default,) + + +def get_decomposition(op) -> tuple: + if op in edge_div_ops: + return ( + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.mul.Tensor, + ) + if op in aten_div_ops: + return ( + torch.ops.aten.full.default, + torch.ops.aten.cat.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.mul.Tensor, + ) + raise RuntimeError(f"Can't get div decomposition for op {op}") + + +class DecomposeAvgPool2d(ExportPass): + """ """ + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_div_ops + aten_div_ops): + return super().call_operator(op, args, kwargs, meta) + + full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) + + x = args[0] + kernel_h, kernel_w = args[1] + kernel_size = kernel_h * kernel_w + stride_h, stride_w = args[2] + pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0) + ceil_mode = args[4] if len(args) > 4 else False + count_include_pad = args[5] if len(args) > 5 else True + divisor_override = args[6] if len(args) > 6 else None + + n, c, h, w = x.data.shape + post_pad_w, post_pad_h = (0, 0) + + # Count_include_pad == False means that we use a different divisor for edge elements + # When divisor_override is set, this will be overriden anyways. + # It is easier to replace a constant divisor, so set count_include_pad == True + if divisor_override is not None: + count_include_pad = True + + # Add width padding manually if count_include_pad + if count_include_pad and pad_w > 0: + pre_pad_shape = [n, c, h, pad_w] + pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) + + if ceil_mode and divisor_override is None: + post_pad_w = pad_w + else: + post_pad_w = adjust_pooling_pad_if_needed( + w, kernel_w, stride_w, pad_w, ceil_mode + ) + + if post_pad_w > 0: + post_pad_shape = [n, c, h, post_pad_w] + post_pad = super().call_operator( + full_op, (post_pad_shape, 0.0), kwargs, meta + ) + cat_nodes = [pre_pad, x, post_pad] + else: + cat_nodes = [pre_pad, x] + + x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta) + new_pad_w = 0 + + # Add height padding manually if count_include_pad + if count_include_pad and pad_h > 0: + pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] + pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) + + if ceil_mode and divisor_override is None: + post_pad_h = pad_h + else: + post_pad_h = adjust_pooling_pad_if_needed( + h, kernel_h, stride_h, pad_h, ceil_mode + ) + + if post_pad_h > 0: + post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] + post_pad = super().call_operator( + full_op, (post_pad_shape, 0.0), kwargs, meta + ) + cat_nodes = [pre_pad, x, post_pad] + else: + cat_nodes = [pre_pad, x] + + x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta) + new_pad_h = 0 + + avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False) + x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta) + + # Multiply by factor (kernel_size / divisor_override) if divisor_override + if divisor_override is not None and divisor_override != kernel_size: + override_multiplier = super().call_operator( + full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta + ) + x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta) + + return x diff --git a/backends/arm/_passes/decompose_grouped_conv.py b/backends/arm/_passes/decompose_grouped_conv.py new file mode 100644 index 00000000000..de96af54adc --- /dev/null +++ b/backends/arm/_passes/decompose_grouped_conv.py @@ -0,0 +1,134 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from copy import copy + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeGroupedConv(ExportPass): + """ + Splits a grouped convolution which is not supported by TOSA into multiple + convolutions using slice->conv->cat. + + Before pass: + x = conv(input, weight, bias, groups = 2) + + After pass: + input1 = slice(input) + weight1 = slice(weight) + bias1 = slice(bias) + x1 = conv(input1, weight1, bias1) + + input2 = slice(input) + weight2 = slice(weight) + bias2 = slice(bias) + x2 = conv(input2, weight2, bias2) + + x = cat(x1, x2) + """ + + @staticmethod + def _get_decomposition(op): + match op: + case exir_ops.edge.aten.convolution.default: + return ( + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.cat.default, + ) + case torch.ops.aten.conv2d.default: + return ( + torch.ops.aten.slice_copy.Tensor, + torch.ops.aten.conv2d.default, + torch.ops.aten.cat.default, + ) + case _: + raise RuntimeError("Unvalid op for grouped conv decomposition.") + + def call_operator(self, op, args, kwargs, meta): + if op == exir_ops.edge.aten.convolution.default: + groups = args[8] + transposed = args[6] + elif op == torch.ops.aten.conv2d.default: + groups = args[6] + transposed = False + else: + return super().call_operator(op, args, kwargs, meta) + + if groups == 1 or transposed: + return super().call_operator(op, args, kwargs, meta) + + input_node = args[0] + if input_node.data.shape[1] == groups: + # This is a depthwise convolution which is handled elsewhere + return super().call_operator(op, args, kwargs, meta) + + weight_node = args[1] + bias_node = args[2] + + input_slice_size = weight_node.data.shape[1] + output_slice_size = weight_node.data.shape[0] // groups + + no_q_dq_meta = copy(meta) + no_q_dq_meta.data = {} + no_q_dq_meta.data = {} + + slice_op, conv_op, cat_op = DecomposeGroupedConv._get_decomposition(op) + + input_slices = [] + for i in range(groups): + start_index = i * input_slice_size + stop_index = (i + 1) * input_slice_size + slice_args = (input_node, 1, start_index, stop_index) + + input_slices.append( + super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + ) + + filter_slices = [] + for i in range(groups): + start_index = i * output_slice_size + stop_index = (i + 1) * output_slice_size + slice_args = (weight_node, 0, start_index, stop_index) + + filter_slices.append( + super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + ) + + bias_slices = [] + for i in range(groups): + if bias_node is None: + bias_slices.append(None) + else: + + start_index = i * output_slice_size + stop_index = (i + 1) * output_slice_size + slice_args = (bias_node, 0, start_index, stop_index) + + bias_slices.append( + super().call_operator(slice_op, slice_args, kwargs, no_q_dq_meta) + ) + + output_slices = [] + for input_slice, filter_slice, bias_slice in zip( + input_slices, filter_slices, bias_slices + ): + + if op == exir_ops.edge.aten.convolution.default: + conv_args = (input_slice, filter_slice, bias_slice, *args[3:8], 1) + elif op == torch.ops.aten.conv2d.default: + conv_args = (input_slice, filter_slice, bias_slice, *args[3:6], 1) + else: + raise RuntimeError("Unvalid op for grouped conv decomposition.") + + output_slices.append( + super().call_operator(conv_op, conv_args, kwargs, meta) + ) + + cat_args = (output_slices, 1) + return super().call_operator(cat_op, cat_args, kwargs, no_q_dq_meta) diff --git a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py index 78cb0deae62..9f036c0524f 100644 --- a/backends/arm/_passes/decompose_linalg_vector_norm_pass.py +++ b/backends/arm/_passes/decompose_linalg_vector_norm_pass.py @@ -51,10 +51,12 @@ def call_operator(self, op, args, kwargs, meta): f"is not supported for linalg_vector_norm operator" ) + # Sum over all dimensions if dim is None if norm_dim is None: - raise ValueError("The norm_dim for linalg_vector_norm is None.") - - dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim) + rank = input_tensor.data.dim() + dims = list(range(rank)) + else: + dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim) # Decomposition based on norm order. if norm_order == 1: diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index 42c2c8c6be9..14baf49bcb2 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -1,5 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,16 +6,16 @@ # pyre-unsafe import numpy as np +from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import ( create_node, get_first_fake_tensor, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import PassResult -class DecomposeLinearPass(ExportPass): +class DecomposeLinearPass(ArmPass): """ This pass decomposes linear into a Conv2D with the required view operations. linear(x, weights, bias) becomes: @@ -24,7 +23,6 @@ class DecomposeLinearPass(ExportPass): weights_reshaped = view(weights) conv2d = conv2d(x_reshaped, weights_reshaped, bias) output = view(conv2d) - It also inserts q/dq pairs if the linear node was quantized. """ def call(self, graph_module): @@ -47,35 +45,22 @@ def call(self, graph_module): weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1] with graph_module.graph.inserting_before(node): - quantize = input.op == "call_function" and input.target == dq_op - q_params = input.args[1:] if quantize else None # Reshape input to 4D with shape (N, Ci, 1, 1) input_reshaped = create_node( graph=graph_module.graph, op_target=exir_ops.edge.aten.view_copy.default, args=(input, input_reshaped_shape), kwargs={}, - quantize=quantize, - q_params=q_params, ) - quantize = weights.op == "call_function" and weights.target == dq_op - q_params = weights.args[1:] if quantize else None # Reshape weights to 4D with shape (Co, Ci, 1, 1) weights_reshaped = create_node( graph=graph_module.graph, op_target=exir_ops.edge.aten.view_copy.default, args=(weights, weights_reshaped_shape), kwargs={}, - quantize=quantize, - q_params=q_params, ) - consumer_node = list(node.users)[0] - quantize = ( - consumer_node.op == "call_function" and consumer_node.target == q_op - ) - q_params = consumer_node.args[1:] if quantize else None conv = create_node( graph=graph_module.graph, op_target=exir_ops.edge.aten.convolution.default, @@ -91,8 +76,7 @@ def call(self, graph_module): 1, # groups ), kwargs={}, - quantize=quantize, - q_params=q_params, + from_node=node, ) with graph_module.graph.inserting_after(conv): diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py new file mode 100644 index 00000000000..ff6db260099 --- /dev/null +++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py @@ -0,0 +1,206 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import operator + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + +# We'll decompose only the EXIR edge max_pool2d ops when dilation > 1 +EDGE_MAXPOOL2D = ( + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, +) + + +class DecomposeMaxPool2DPass(ArmPass): + """ + Decompose dilated max_pool2d (EXIR edge ops) into space-to-batch -> maxpool -> batch-to-space. + """ + + def call_operator(self, op, args, kwargs, meta): + # Only intercept EXIR edge max_pool2d ops + if op not in EDGE_MAXPOOL2D: + return super().call_operator(op, args, kwargs, meta) + + # detect whether indices variant + is_with_indices = op is exir_ops.edge.aten.max_pool2d_with_indices.default + + # Normalize missing trailing args to their defaults + x = args[0] + kernel_size = args[1] + stride = args[2] + padding = args[3] if len(args) >= 4 else 0 + dilation = args[4] if len(args) >= 5 else 1 + ceil_mode = args[5] if len(args) == 6 else False + + # Normalize attributes + pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding + d_h, d_w = (dilation, dilation) if isinstance(dilation, int) else dilation + k_h, k_w = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + s_h, s_w = (stride, stride) if isinstance(stride, int) else stride + + # If no dilation: call EXIR edge op + if d_h == 1 and d_w == 1: + minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode] + return super().call_operator(op, tuple(minimal_args), {}, meta) + + # Compute padded and packed dimensions for dilation > 1 + N, C, H, W = x.data.size() + ph, pw = pad_h, pad_w + ph2, pw2 = pad_h, pad_w + H_pad = H + ph + ph2 + W_pad = W + pw + pw2 + H_pack = (H_pad + d_h - 1) // d_h + W_pack = (W_pad + d_w - 1) // d_w + extra_h = 0 if H_pack < k_h else (s_h - ((H_pack - k_h) % s_h)) % s_h + extra_w = 0 if W_pack < k_w else (s_w - ((W_pack - k_w) % s_w)) % s_w + ph2 += extra_h * d_h + pw2 += extra_w * d_w + + # 1) Pad via EXIR edge pad (preserves dtype) + pad_edge = exir_ops.edge.aten.constant_pad_nd.default + pads = [pw, pw2, ph, ph2, 0, 0, 0, 0] + x_pad = super().call_operator( + pad_edge, + (x, pads, 0), + {}, + meta, + ) + + # 2) Space-to-batch: reshape and permute + x2 = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (x_pad, [N, C, H_pack, d_h, W_pack, d_w]), + {}, + meta, + ) + x2 = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (x2, [3, 5, 0, 1, 2, 4]), + {}, + meta, + ) + x2 = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (x2, [N * d_h * d_w, C, H_pack, W_pack]), + {}, + meta, + ) + + # 3) Core pooling on packed tensor + pool_edge_op = ( + exir_ops.edge.aten.max_pool2d_with_indices.default + if is_with_indices + else exir_ops.edge.aten.max_pool2d.default + ) + pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode) + pool_out = super().call_operator( + pool_edge_op, + pool_args, + {}, + meta, + ) + + # Unpack pooled result + if is_with_indices: + pooled_proxy = super().call_operator( + operator.getitem, + (pool_out, 0), + {}, + meta, + ) + indices_proxy = super().call_operator( + operator.getitem, + (pool_out, 1), + {}, + meta, + ) + pooled_fake, _ = pool_out.data + else: + pooled_proxy = pool_out + pooled_fake = pool_out.data + indices_proxy = None + + _, C_out, H_out, W_out = pooled_fake.shape + + # 4) Batch-to-space: reshape and permute back + out = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (pooled_proxy, [d_h, d_w, N, C_out, H_out, W_out]), + {}, + meta, + ) + out = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (out, [2, 3, 4, 0, 5, 1]), + {}, + meta, + ) + # now flatten back into (N, C, H_out*d_h, W_out*d_w) + out = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (out, [N, C_out, H_out * d_h, W_out * d_w]), + {}, + meta, + ) + + # 5) Final crop + S_top = ph // d_h + (1 if ph % d_h else 0) + S_left = pw // d_w + (1 if pw % d_w else 0) + S_top = max(0, min(S_top, H_out * d_h - H)) + S_left = max(0, min(S_left, W_out * d_w - W)) + out = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (out, 2, S_top, S_top + H), + {}, + meta, + ) + out = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (out, 3, S_left, S_left + W), + {}, + meta, + ) + + if is_with_indices: + # Reconstruct indices + idx = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (indices_proxy, [d_h, d_w, N, C_out, H_out, W_out]), + {}, + meta, + ) + idx = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (idx, [2, 3, 4, 0, 5, 1]), + {}, + meta, + ) + idx = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (idx, [N, C_out, H_out * d_h, W_out * d_w]), + {}, + meta, + ) + idx = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (idx, 2, S_top, S_top + H), + {}, + meta, + ) + idx = super().call_operator( + exir_ops.edge.aten.slice_copy.Tensor, + (idx, 3, S_left, S_left + W), + {}, + meta, + ) + return out, idx + + return out diff --git a/backends/arm/_passes/decompose_round_pass.py b/backends/arm/_passes/decompose_round_pass.py new file mode 100644 index 00000000000..edfa3817064 --- /dev/null +++ b/backends/arm/_passes/decompose_round_pass.py @@ -0,0 +1,84 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch._ops import OpOverload + + +Op = OpOverload | EdgeOpOverload + + +def _get_round_decomposition_ops(op) -> tuple[Op, Op, Op, Op, Op, Op, Op]: + """ + Returns the (full_op, ge_op, add_op, sub_op, floor_op, ceil_op, where_op) for the + given round operation. The ops depend on whether the round op is an aten or edge op. + """ + if op == exir_ops.edge.aten.round.default: + return ( + exir_ops.edge.aten.full.default, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.add.Scalar, + exir_ops.edge.aten.sub.Scalar, + exir_ops.edge.aten.floor.default, + exir_ops.edge.aten.ceil.default, + exir_ops.edge.aten.where.self, + ) + elif op == torch.ops.aten.round.default: + return ( + torch.ops.aten.full.default, + torch.ops.aten.ge.Tensor, + torch.ops.aten.add.Scalar, + torch.ops.aten.sub.Scalar, + torch.ops.aten.floor.default, + torch.ops.aten.ceil.default, + torch.ops.aten.where.self, + ) + raise RuntimeError(f"Can't get round decomposition ops for op {op}") + + +class DecomposeRoundPass(ArmPass): + """ + For inputs >= 0, round(x) is equivalent to floor(x + 0.5), and for inputs < 0, + round(x) is equivalent to ceil(x - 0.5). This pass decomposes the round operation into + a sequence of more primitive operations. + Example: + %zero = full((1,), 0.0, dtype=torch.float32) + %is_non_negative = ge(x, %zero) + %plus_half = add(x, 0.5) + %minus_half = sub(x, 0.5) + %floor = floor(%plus_half) + %ceil = ceil(%minus_half) + %result = where(%is_non_negative, %floor, %ceil) + """ + + def call_operator(self, op, args, kwargs, meta, updated=False): + if op not in (exir_ops.edge.aten.round.default, torch.ops.aten.round.default): + return super().call_operator(op, args, kwargs, meta, updated) + x = args[0] + full, ge, add, sub, floor, ceil, where = _get_round_decomposition_ops(op) + zero = super().call_operator( + full, + args=((1,), 0.0), + kwargs={"dtype": torch.float32}, + meta=meta, + updated=True, + ) + is_non_negative = super().call_operator( + ge, (x, zero), kwargs, meta, updated=True + ) + plus_half = super().call_operator(add, (x, 0.5), kwargs, meta, updated=True) + minus_half = super().call_operator(sub, (x, 0.5), kwargs, meta, updated=True) + floor = super().call_operator(floor, (plus_half,), kwargs, meta, updated=True) + ceil = super().call_operator(ceil, (minus_half,), kwargs, meta, updated=True) + return super().call_operator( + where, + (is_non_negative, floor, ceil), + kwargs, + meta, + updated=True, + ) diff --git a/backends/arm/_passes/decompose_sum_pass.py b/backends/arm/_passes/decompose_sum_pass.py index 531b0d72a19..52b9c10c49f 100644 --- a/backends/arm/_passes/decompose_sum_pass.py +++ b/backends/arm/_passes/decompose_sum_pass.py @@ -63,6 +63,11 @@ def call_operator(self, op, args, kwargs, meta): case _: raise ValueError(f"Invalid number of arguments ({len(args)}) provided.") + # If dims is None, sum over all dimensions + if dims is None: + shape = input_node.data.size() + dims = list(range(len(shape))) + view_op, sum_op = _get_sum_decomp(op) for dim in dims: diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 63c57e1bedd..d2c3ea8582d 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -10,7 +10,13 @@ from typing import cast, Dict, Set, Tuple -from executorch.backends.arm.tosa_quant_utils import QuantArgs +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + get_param_tensor, + is_param_node, +) + +from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload @@ -24,9 +30,6 @@ ) from torch.fx import GraphModule, Node -q_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default -dq_op: EdgeOpOverload = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """ @@ -66,7 +69,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: return output_qparams -class FoldAndAnnotateQParamsPass(ExportPass): +class FoldAndAnnotateQParamsPass(ArmPass): """ A pass that walks the graph and removes any DQ and Q nodes before and after the target node. @@ -96,9 +99,6 @@ class FoldAndAnnotateQParamsPass(ExportPass): """ - def __init__(self) -> None: - super().__init__() - def fold_and_annotate_arg( self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int ) -> None: @@ -109,8 +109,25 @@ def fold_and_annotate_arg( return arg_quant_params = None - if arg.target == dq_op: - arg_quant_params = QuantArgs.from_operator(arg.target, arg.args) + if arg.target in dq_ops: + args = arg.args + scales = args[1] + if ( + isinstance(args[1], Node) + and self.exported_program is not None + and is_param_node(self.exported_program, args[1]) + ): + scales = get_param_tensor(self.exported_program, args[1]) + zps = args[2] + if ( + isinstance(args[2], Node) + and self.exported_program is not None + and is_param_node(self.exported_program, args[2]) + ): + zps = get_param_tensor(self.exported_program, args[2]) + arg_quant_params = QuantArgs.from_operator( + arg.target, (args[0], scales, zps, *args[3:]) + ) # add arg to nodes_to_remove to fold the dq-node nodes_to_remove.add(arg) if input_qparams is not None and input_qparams != arg_quant_params: @@ -120,11 +137,14 @@ def fold_and_annotate_arg( if input_qparams is not None: node.meta["input_qparams"][i] = input_qparams for n in nodes_to_remove: - if n.target != dq_op: - raise RuntimeError(f"Expected {dq_op} dq_op, got {n.target}") + if n.target not in dq_ops: + raise RuntimeError( + f"Expected one of {dq_ops} dq_op, got {n.target}" + ) - n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type] - graph_module.graph.erase_node(n) + node.replace_input_with(n, cast(Node, n.args[0])) + if len(n.users) == 0: + graph_module.graph.erase_node(n) def call(self, graph_module: GraphModule) -> PassResult: @@ -134,7 +154,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if n.op != "call_function": continue # Don't fold chains of quant-ops into each other. - if n.target in (q_op, dq_op): + if n.target in (*q_ops, *dq_ops): continue # Make sure we haven't already set qparams meta information on the node @@ -164,7 +184,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Copy the users, since we are modifying it. users_copy = copy.copy(n.users) for i, user in enumerate(users_copy): - if user.target != q_op: + if user.target not in q_ops: continue # quantization node found here, store the quantization parameters in meta value @@ -201,7 +221,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # Make sure we have a quantized operator user = list(n.users)[0] - if user.target != q_op: + if user.target not in q_ops: continue qargs = QuantArgs.from_operator(user.target, user.args) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 56e124d8d0a..f70614d6231 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -98,7 +98,7 @@ def _fuse_nodes(self, node) -> bool: def call(self, graph_module): modified = False - input_nodes_to_delete = [] + input_nodes_to_maybe_delete = set() for node in graph_module.graph.nodes: if node.op != "call_function": continue @@ -116,26 +116,29 @@ def call(self, graph_module): or torch._export.utils.is_buffer(self.exported_program, input_node) for input_node in input_nodes ) - input_nodes_single_users = ( - len(input_node.users) == 1 for input_node in input_nodes - ) + if not all(input_nodes_constant): + continue - if all(input_nodes_constant) and all(input_nodes_single_users): - try: - did_fuse = self._fuse_nodes(node) - modified |= did_fuse - if did_fuse: - graph_module.recompile() # Recompile needed to catch chains of constant ops - input_nodes_to_delete.extend(input_nodes) - except Exception as e: - logger.warning( - f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}" + try: + did_fuse = self._fuse_nodes(node) + if did_fuse: + logger.debug( + f"Fused constant op: {node.name} with placeholder inputs:" + f"{[input_node.name for input_node in input_nodes]}" ) + modified |= did_fuse + graph_module.recompile() # Recompile needed to catch chains of constant ops + input_nodes_to_maybe_delete.update(input_nodes) + except Exception as e: + logger.warning( + f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}" + ) if modified: graph_module.graph.eliminate_dead_code() - for input_node in input_nodes_to_delete: - delete_constant_placeholder(self.exported_program, input_node) + for input_node in input_nodes_to_maybe_delete: + if len(input_node.users) == 0: + delete_constant_placeholder(self.exported_program, input_node) graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index cd8cce1b3ea..664a0f8ea6c 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -49,7 +49,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if tensor2 is None: continue - if torch.equal(tensor1, tensor2): + if ( + tensor1.dtype == tensor2.dtype + and tensor1.shape == tensor2.shape + and torch.allclose(tensor1, tensor2, atol=1e-08) + ): eq_nodes.append(node2) if len(eq_nodes) > 1: diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 13c69bf92f1..f70d6d8755b 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe import torch -from executorch.backends.arm.tosa_quant_utils import q_op +from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node @@ -21,11 +21,12 @@ def _is_fuseable_quantized_activation(node: Node): min_val = node.args[1] is_fuseable = min_val == 0 - is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op + is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops if is_fuseable and is_quantized: quant_node = next(iter(node.users)) - zp = quant_node.args[2] - qmin = quant_node.args[3] + quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args) + zp = quant_args.zp + qmin = quant_args.qmin return zp == qmin else: return False diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 541638b830e..97b8fb15711 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -9,7 +9,7 @@ import torch from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs +from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs from executorch.exir.pass_base import ExportPass, PassResult from torch import Tensor from torch.fx import GraphModule, Node @@ -94,11 +94,11 @@ def call(self, graph_module: GraphModule) -> PassResult: for node in graph_module.graph.nodes: node = cast(Node, node) - if node.target is not dq_op: + if node.target not in dq_ops: continue # Copy users since we remove them while iterating, modyfing the node.users list. for user in copy(node.users): - if user.target is q_op: + if user.target in q_ops: self.fold_dq_q_to_rescale(node, user, graph_module) modified = True if len(node.users) == 0: diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index aeb9d3bc5eb..402ed0253c0 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -143,9 +143,7 @@ def f(x: torch.Tensor) -> torch.Tensor: start=in_quantargs.qmin, end=in_quantargs.qmax, steps=256, - # use torch.int64 to avoid overflow when dequantizing (subtracting zp). - # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8) - dtype=torch.int64, + dtype=torch.int8, ) ).to(dtype=torch.int8), 0, @@ -173,6 +171,9 @@ def generate_16_bit_table_values( """ def f(x: torch.Tensor) -> torch.Tensor: + x = x.clamp(in_quantargs.qmin, in_quantargs.qmax).to( + dtype=in_quantargs.dtype + ) # Dont use the 7 LSBs. x = in_quantargs.dequantize_value((x & ~0x7F)) x = torch_op(x) @@ -183,9 +184,8 @@ def f(x: torch.Tensor) -> torch.Tensor: start=in_quantargs.qmin, end=in_quantargs.qmax + 1, steps=513, - # use torch.int64 to avoid overflow when dequantizing (subtracting zp). - # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8) - dtype=torch.int64, + # use torch.int32 to avoid overflow for end=in_quantargs.qmax + 1. + dtype=torch.int32, ) ) # Calculate how much we need to shift table values to fit in 16 signed bits diff --git a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_where_self_arg_dtype_pass.py index 154602129f8..fdbd4433bab 100644 --- a/backends/arm/_passes/match_where_self_arg_dtype_pass.py +++ b/backends/arm/_passes/match_where_self_arg_dtype_pass.py @@ -49,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule): input_dtype = input_.meta["val"].dtype other_dtype = other_.meta["val"].dtype - target_dtype = torch.float32 + target_dtype = input_dtype if input_dtype != other_dtype: target_dtype = get_largest_dtype(input_dtype, other_dtype) diff --git a/backends/arm/_passes/mm_to_bmm_pass.py b/backends/arm/_passes/mm_to_bmm_pass.py index 34ac7553212..519b755080c 100644 --- a/backends/arm/_passes/mm_to_bmm_pass.py +++ b/backends/arm/_passes/mm_to_bmm_pass.py @@ -12,7 +12,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import Node @@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule): node.replace_input_with(input_node, unsqueeze_before) # If Quantized we must insert unsqueeze --> q --> dq --> node - if input_node.target == dq_op: + if input_node.target in dq_ops: q_params = input_node.args[1:] insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node) @@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule): user.replace_input_with(bmm_node, squeeze_after) # If quantized, insert mm --> q --> dq --> squeeze - if all(original_user.target == q_op for original_user in original_users): + if all(original_user.target in q_ops for original_user in original_users): q_params = original_users[0].args[1:] insert_q_dq_pair(graph, bmm_node, q_params, from_node=node) diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index bdbbbfd1162..ece26ae4f81 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -10,12 +10,15 @@ # backends. Converts via TOSA as an intermediate form supported by AoT and # JIT compiler flows. # - from typing import List, Optional -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_specification import ( # type: ignore[import-not-found] + TosaSpecification, +) -from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] + CompileSpec, +) class ArmCompileSpecBuilder: @@ -28,6 +31,7 @@ def __init__(self): def vgf_compile_spec( self, + tosa_spec: TosaSpecification = None, # type: ignore[assignment] compiler_flags: Optional[str] = "", ) -> "ArmCompileSpecBuilder": """ @@ -40,7 +44,33 @@ def vgf_compile_spec( self.compiler_flags = [ compiler_flags, ] - self.tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+MI") + + if tosa_spec is None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + + tosa_version = tosa_spec.version # type: ignore[attr-defined] + tosa_profiles = tosa_spec.profiles # type: ignore[attr-defined] + + if tosa_version.major != 1: + raise ValueError( + "Arm backend only supports converter-backend for TOSA version 1. " + f"Invalid TOSA version: {tosa_version}" + ) + + if not ("FP" or "INT" in tosa_profiles): + raise ValueError( + "Arm backend only supports converter-backend for FP or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + if len(tosa_profiles) != 1: + raise ValueError( + "For now Arm backend only supports converter-backend for either FP or INT. " + f"Invalid TOSA profile: {tosa_profiles}" + ) + + self.tosa_spec = tosa_spec + return self def ethosu_compile_spec( diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 4a1f1269fe2..2075e0f554f 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -10,6 +10,7 @@ embedding_support, ethos_u55_support, index_select_support, + index_tensor_support, minmax_support, pool_2d_support, reduce_sum_support, diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py new file mode 100644 index 00000000000..7330f98667d --- /dev/null +++ b/backends/arm/operator_support/index_tensor_support.py @@ -0,0 +1,128 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.fx as fx +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class IndexTensorSupported(SupportedTOSAOperatorCheck): + """ + This support check is intended to prevent the partitioning of + currently unsupported usages of the index.Tensor operator. + + 1. Usages where indexing tensors are of rank 4 or higher. + This is due to the AnnotateChannelsLastDimOrder pass and + the rarity of such operation. + Support is possible but would require further changes to the above + pass which can be added at such a time as is necessary. + + 2. Usages where slice, ellipsis or None are present before an indexing tensor: + t[{start}:{end}, indexTensor] - slicing + t[None, indexTensor] - unsqueeze + t[..., indexTensor] - ellipsis + + 3. Usages where the value tensor contains more than int32.max elements + This is due to int32 TOSA limitation and the fact that we flatten out + and accumulate all index tensors. + As such to avoid overflow we reject lowering of this operator if it is + possible for indices to go over the int32 limit. + + Extra information regarding #2: + Pytorch decomposes slice and None usages before they reach aten. + In the case of Slicing and Unsqueeze, Pytorch will add the relevant + operation just before the index.Tensor op. + In the case of Ellipsis no extra operation is added. + + In all three cases Pytorch will insert "None"(s) in the index list + only if the above operations are done on a dimension BEFORE one being indexed. + + When slicing, unsqueeze and ellipsis are done on dimensions after + the ones being indexed, then they do not affect the final output + values, only the shape. Thus None is not passed to the index.Tensor op. + + The purpose of None is to signify to index.Tensor that a dimension + should not be indexed. + In such cases the logic behaves similar to batching along that dimension. + For the sake of simplicity we have not implemented this behavior yet + and thus have put this support check in place to prevent the partitioning + of index.Tensor ops which include None. + + Examples: + #1 - Slice ----------------------------------------------------- + t = torch.randint(25, size(25, 3, 6)) + t[1:5, torch.arange(3)] + + Turns into: (edge pseudo code) + slice_res = ...edge__ops_aten_slice_copy_Tensor(t, dim=0, start=1, end=2) + out = ...edge__ops_aten_index_Tensor(slice_res, [None, torch.arange(3)]) + + #2 - None (Unsqueeze) ------------------------------------------ + t = torch.randint(25, size(25, 3, 6)) + t[None, torch.arange(3)] + + Turns into: edge pseudo code) + unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=0) + out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [None, torch.arange(3)]) + + #3 - None (Unsqueeze) After index ------------------------------ + t = torch.randint(25, size(25, 3, 6)) + t[torch.arange(3), None] + + Turns into: edge pseudo code) + unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=1) + out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [torch.arange(3)]) + + NB. + With the current implementation of flattening tensors and indices out, + supporting None (Unsqueeze) is simply a matter of ignoring the + None dimension. + This is not the case for Slice and Ellipsis operators, where + the size of the new dimension can be > 1. + + Note that slice ops interleaved between indexes such as: + t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)] + are also possible and can result in some unintuitive behaviors + where batching and indexing are mixed together. + """ + + targets = [exir_ops.edge.aten.index.Tensor] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + indices = node.args[1] + for index in indices: # type: ignore[union-attr] + # Usage 2 guard + if index is None: + return False + + # Usage 1 guard + fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type] + if len(fake_tensor.size()) > 3: + return False + + # Usage 3 guard + total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type] + if total_vals > torch.iinfo(torch.int32).max: + return False + + return True diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index 9db58f663d3..677436ddc50 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -12,6 +12,9 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + adjust_pooling_pad_if_needed, +) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops @@ -56,25 +59,42 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): input_arg = get_first_fake_tensor(input_arg) shape = input_arg.data.shape # type: ignore[union-attr] + # Calculate padding used in the final TOSA operator kernel = cast(tuple[int, int], node.args[1]) stride = cast(tuple[int, int], node.args[2]) - if len(node.args) > 3: - padding = cast(tuple[int, int], node.args[3]) - # Padding case - if not all(1 <= k <= 8 for k in kernel) and not all( - v == 0 for v in padding - ): - self.reporter.report_reject( - node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}" - ) - return False + padding = cast(tuple[int, int], node.args[3]) if len(node.args) > 3 else (0, 0) + ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False + count_include_pad = cast(bool, node.args[5]) if len(node.args) > 5 else True + divisor_override = cast(int, node.args[6]) if len(node.args) > 6 else None + + # If count_include_pad is True or divior_override is given, padding is applied + # by concating zero-elements rather than setting it in the avg_pool op. + if count_include_pad or divisor_override is not None: + tosa_padding = (0, 0, 0, 0) + # Otherwise, calculate the padding as done in the node visitor else: - if not kernel_check(kernel): - self.reporter.report_reject( - node, - f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}", - ) - return False + post_pad_h = adjust_pooling_pad_if_needed( + shape[2], kernel[0], stride[0], padding[0], ceil_mode + ) + post_pad_w = adjust_pooling_pad_if_needed( + shape[3], kernel[1], stride[1], padding[1], ceil_mode + ) + tosa_padding = (padding[0], post_pad_h, padding[1], post_pad_w) + + if not all(1 <= k <= 8 for k in kernel) and not all( + v == 0 for v in tosa_padding + ): + self.reporter.report_reject( + node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}" + ) + return False + + if not kernel_check(kernel): + self.reporter.report_reject( + node, + f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}", + ) + return False if not dim_check(shape): self.reporter.report_reject( diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 36ae77d26a3..7a893acaf80 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -24,6 +24,7 @@ EthosU55NotSupported, EthosU55TransposeCheck, ) +from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter @@ -210,6 +211,7 @@ def is_node_supported( exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.rsqrt.default, + exir_ops.edge.aten.round.default, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten._log_softmax.default, @@ -228,7 +230,9 @@ def is_node_supported( exir_ops.edge.aten.where.self, operator.getitem, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.aten.constant_pad_nd.default, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, @@ -278,6 +282,7 @@ def is_node_supported( exir_ops.edge.aten.ne.Scalar: None, exir_ops.edge.aten.div.Scalar: None, exir_ops.edge.aten.leaky_relu.default: None, + exir_ops.edge.aten.round.default: None, } if node.target in needs_decomp_dict: @@ -298,8 +303,6 @@ class CheckProperQuantization(OperatorSupportBase): activations. """ - dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default targeted_ops = ( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, @@ -349,7 +352,7 @@ def _is_matmul_node_supported( matched_partition = partition if matched_partition is not None: input_quantized = all( - input_node.target == self.dq_op + input_node.target in dq_ops for input_node in matched_partition.input_nodes ) if not input_quantized: @@ -358,7 +361,7 @@ def _is_matmul_node_supported( ) return False output_quantized = all( - output_node_user.target == self.q_op + output_node_user.target in q_ops for output_node_user in matched_partition.output_nodes[0].users ) if not output_quantized: @@ -394,7 +397,7 @@ def is_node_supported( users = node.users output_quantized = all( user.target == operator.getitem - and all(user_user.target == self.q_op for user_user in user.users) + and all(user_user.target in q_ops for user_user in user.users) for user in users ) elif FuseQuantizedActivationPass._is_fuseable_input(node): @@ -408,7 +411,7 @@ def is_node_supported( input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node) input_quantized = input_quantized or all( - (input_node.target == self.dq_op) + (input_node.target in dq_ops) or (not get_first_fake_tensor(input_node).dtype.is_floating_point) for input_node in node.all_input_nodes ) @@ -417,9 +420,7 @@ def is_node_supported( self.reporter.report_reject(node, "One or more inputs were not quantized.") return False - all_q_users = all( - (output_node.target == self.q_op) for output_node in node.users - ) + all_q_users = all((output_node.target in q_ops) for output_node in node.users) is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point output_quantized = output_quantized or all_q_users or not is_floating_point diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 1e2620e4533..260299d6423 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -25,6 +25,7 @@ op_ge, op_gt, op_index_select, + op_index_tensor, op_le, op_log, op_lt, diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index d27645839ea..f839ca380ec 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -54,6 +54,11 @@ def _build_generic_avgpool2d( kernel_size_list = inputs[1].special stride_size_list = inputs[2].special + if len(inputs) > 4: + ceil_mode = bool(inputs[4].number) + else: + ceil_mode = False + try: pad_size_list = inputs[3].special pad_size_list = [ @@ -71,12 +76,14 @@ def _build_generic_avgpool2d( kernel_size_list[0], stride_size_list[0], pad_size_list[1], + ceil_mode, ) pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size_list[1], stride_size_list[1], pad_size_list[3], + ceil_mode, ) attr = ts.TosaSerializerAttribute() @@ -105,7 +112,7 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 6]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec @@ -114,10 +121,10 @@ def define_node( accumulator_type = ts.DType.INT32 input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].zp + input_zp = input_qargs[0].get_zp_per_tensor() output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].zp + output_zp = output_qargs[0].get_zp_per_tensor() self._build_generic_avgpool2d( node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type @@ -141,7 +148,7 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 6]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, @@ -192,6 +199,11 @@ def _build_generic_avgpool2d( kernel_size_list = inputs[1].special stride_size_list = inputs[2].special + if len(inputs) > 4: + ceil_mode = bool(inputs[4].number) + else: + ceil_mode = False + try: pad_size_list = inputs[3].special pad_size_list = [ @@ -209,12 +221,14 @@ def _build_generic_avgpool2d( kernel_size_list[0], stride_size_list[0], pad_size_list[1], + ceil_mode, ) pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size_list[1], stride_size_list[1], pad_size_list[3], + ceil_mode, ) attr = ts.TosaSerializerAttribute() @@ -247,7 +261,7 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 6]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec @@ -256,10 +270,10 @@ def define_node( accumulator_type = ts.DType.INT32 input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].zp + input_zp = input_qargs[0].get_zp_per_tensor() output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].zp + output_zp = output_qargs[0].get_zp_per_tensor() self._build_generic_avgpool2d( node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type @@ -286,7 +300,7 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4, 6]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 3aa63abe1a0..68b5b363703 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -65,8 +65,8 @@ def define_node( # for a later rescale. if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) - input0_zp = input_qparams[0].zp - input1_zp = input_qparams[1].zp + input0_zp = input_qparams[0].get_zp_per_tensor() + input1_zp = input_qparams[1].get_zp_per_tensor() bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -88,8 +88,8 @@ def define_node( if output.dtype == ts.DType.INT8: output_qparams = get_output_qparams(node)[0] final_output_scale = ( - input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61] - ) / output_qparams.scale + input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] + ) / output_qparams.get_scale_per_tensor() build_rescale_v0_80( tosa_fb=tosa_graph, @@ -98,8 +98,8 @@ def define_node( input_node=bmm_result, # type: ignore[possibly-undefined] output_name=output.name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=output_qparams.zp, + input_zp=[0], + output_zp=[output_qparams.get_zp_per_tensor()], is_double_round=False, ) @@ -142,8 +142,8 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) - input0_zp = input_qparams[0].zp - input1_zp = input_qparams[1].zp + input0_zp = input_qparams[0].get_zp_per_tensor() + input1_zp = input_qparams[1].get_zp_per_tensor() bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -169,8 +169,8 @@ def define_node( if output.dtype == ts.DType.INT8: output_qparams = get_output_qparams(node)[0] final_output_scale = ( - input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61] - ) / output_qparams.scale + input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] + ) / output_qparams.get_scale_per_tensor() build_rescale( tosa_fb=tosa_graph, @@ -179,7 +179,7 @@ def define_node( input_node=bmm_result, # type: ignore[possibly-undefined] output_name=output.name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=output_qparams.zp, + input_zp=[0], + output_zp=[output_qparams.get_zp_per_tensor()], rounding_mode=RoundingMode.SINGLE_ROUND, ) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index a566b0fbfa7..3c73e7b32c0 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +import itertools from typing import Any, List import torch @@ -97,7 +98,7 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # int8 input requires quantization information input_qparams = get_input_qparams(node) - input_zp = input_qparams[0].zp + input_zp = input_qparams[0].get_zp_per_tensor() attr.ConvAttribute( pad=pad_attr, @@ -108,24 +109,6 @@ def define_node( local_bound=False, ) - # Non-bias case. - if len(node.all_input_nodes) == 2: - # Create a zero bias tensor if not presented - out_channels = weight.shape[0] - bias_name = "bias" + node.name.split("default", 1)[1] - bias_type = output.dtype - if output.dtype == ts.DType.INT8: - # Conv is quantized to int8, but the TOSA operator has - # output type int32, and the bias must be the same type - # as the TOSA output type - bias_type = ts.DType.INT32 - bias = tosa_graph.addConst( - [out_channels], - bias_type, - [0] * out_channels, - name=bias_name, - ) - # The output type is int32 when input type is int8. conv2d_output_name = output.name if output.dtype == ts.DType.INT8: @@ -178,13 +161,22 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if inputs[0].dtype == ts.DType.INT8: # Get scale_factor from input, weight, and output. - input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] - weight_scale = input_qparams[1].scale # pyre-ignore [61] + input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] + + per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] + if per_channel_quant: + weight_scale = input_qparams[1].get_scale_per_channel() + else: + weight_scale = [ + input_qparams[1].get_scale_per_tensor() + ] # pyre-ignore [61] output_qargs = get_output_qparams(node) post_conv2d_scale = [ (inp * w) / out for inp, w, out in zip( - [input_scale], [weight_scale], [output_qargs[0].scale] + itertools.cycle([input_scale]), + weight_scale, + itertools.cycle([output_qargs[0].get_scale_per_tensor()]), ) ] @@ -194,9 +186,9 @@ def define_node( input_node=conv2d_res, # type: ignore[possibly-undefined] output_name=output.name, output_type=output.dtype, - input_zp=0, - output_zp=output_qargs[0].zp, - per_channel=isinstance(weight_scale, torch.Tensor), + input_zp=[0], + output_zp=[output_qargs[0].get_zp_per_tensor()], + per_channel=per_channel_quant, ) # type: ignore[call-arg] @@ -274,7 +266,13 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: # int8 input requires quantization information input_qparams = get_input_qparams(node) - input_zp = input_qparams[0].zp + input_zp = input_qparams[0].get_zp_per_tensor() + + weight_zp = 0 + if inputs[1].dtype == ts.DType.INT8: + # int8 weights requires quantization information + input_qparams = get_input_qparams(node) + weight_zp = input_qparams[1].zp # type: ignore[assignment] # The output type is int32 when input type is int8. conv2d_output_name = output.name @@ -291,27 +289,12 @@ def define_node( [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" ) tosa_graph.addConst( - [1], output.dtype, [0], name=f"{conv2d_output_name}_weight_zp" + [1], + output.dtype, + weight_zp, + name=f"{conv2d_output_name}_weight_zp", ) - # Non-bias case. - if len(node.all_input_nodes) == 2: - # Create a zero bias tensor if not presented - out_channels = weight.shape[0] - bias_name = f"{conv2d_output_name}_bias" - bias_type = output.dtype - if output.dtype == ts.DType.INT8: - # Conv is quantized to int8, but the TOSA operator has - # output type int32, and the bias must be the same type - # as the TOSA output type - bias_type = ts.DType.INT32 - bias = tosa_graph.addConst( - [out_channels], - bias_type, - [0] * out_channels, - name=bias_name, - ) - # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W) in_channels = input.shape[1] out_channels = weight.shape[0] @@ -388,13 +371,21 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if inputs[0].dtype == ts.DType.INT8: # Get scale_factor from input, weight, and output. - input_scale = input_qparams[0].scale # type: ignore[possibly-undefined] # pyre-ignore [61] - weight_scale = input_qparams[1].scale # pyre-ignore [61] + input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] + per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] + if per_channel_quant: + weight_scale = input_qparams[1].get_scale_per_channel() + else: + weight_scale = [ + input_qparams[1].get_scale_per_tensor() + ] # pyre-ignore [61] output_qargs = get_output_qparams(node) post_conv2d_scale = [ (inp * w) / out for inp, w, out in zip( - [input_scale], [weight_scale], [output_qargs[0].scale] + itertools.cycle([input_scale]), + weight_scale, + itertools.cycle([output_qargs[0].get_scale_per_tensor()]), ) ] build_rescale( @@ -403,8 +394,8 @@ def define_node( input_node=conv2d_res, # type: ignore[possibly-undefined] output_name=output.name, output_type=output.dtype, - input_zp=0, - output_zp=output_qargs[0].zp, - per_channel=isinstance(weight_scale, torch.Tensor), + input_zp=[0], + output_zp=[output_qargs[0].get_zp_per_tensor()], + per_channel=per_channel_quant, rounding_mode=RoundingMode.SINGLE_ROUND, ) diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py new file mode 100644 index 00000000000..8c5c84ddd5a --- /dev/null +++ b/backends/arm/operators/op_index_tensor.py @@ -0,0 +1,354 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import math +from typing import Any, List + +import executorch.backends.arm.tosa_utils as tutils + +import numpy as np + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_same_dtype, +) +from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from torch.fx import Node + + +class CommonIndexTensorVisitor(NodeVisitor): + target = "aten.index.Tensor" + + def __init__(self, *args): + super().__init__(*args) + + def _get_tensor_info(self, tensor: Node): + """ + Consolidates obtaining name, dtype and shape into a common function + reconciling access based on the type of the input. + + Args: + fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors + who's shapes to evaluate + + Returns: + tuple[bool, list[int]]: First element is whether the shapes are + broadcastable. Second element is the common shape if compatible. + If not, empty list. + + """ + if isinstance(tensor, Node): + dtype, shape, _ = extract_tensor_meta(tensor.meta, self.tosa_spec) + return tensor.name, dtype, shape + else: + return tensor.name, tensor.dtype, tensor.shape + + def _calculate_tosa_vals( + self, + index_shape: List[int], + index_nodes: List[Node], + values_shape: List[int], + ): + # From TOSA spec + # N - number of batches + # W - number of indices in each batch + # K - range of each index (number of elements to index through) + # C - number of data channels for each index + N, K, W, C = 1, 1, 1, 1 + + # Calculate K, W, C + # N - kept to 1 for generic n-dim implementation + # Note: If/when slice and ellipsis support is added batching + # may have to be used to facilitate proper implementation of + # the relevant logic. + # W - common between all indices as they have been broadcast + # to a common shape in a pass. + W = math.prod(index_shape) + + for i, dim in enumerate(values_shape): + if i < len(index_nodes): + K *= dim + + total_vals = math.prod(values_shape) + C = int(total_vals / K) + + return N, K, W, C + + def _calculate_value_strides(self, values_shape: List[int]) -> List[int]: + values_strides: List[int] = [] + stride = 1 + for dim in reversed(values_shape): + values_strides.insert(0, stride) + stride *= dim + + return values_strides + + +@register_node_visitor +class IndexTensorVisitor_080(CommonIndexTensorVisitor): + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """ + This approach uses the fact that all indexing tensors are incremented + simultaneously and they essentially act as a map along the corresponding + dimensions of the values tensor. + Note: that this does not hold true when slicing or ellipsis ops + are involved as such they are not currently not supported. + + As such this approach flattens out the values tensor and + constructs a flattened out index obtained by flattening out the + index tensors, multiplying them by the relevant stride and accumulating them. + + This approach suffers from the fact that we are taking a number of index tensors of + type int32 and applying multiplications and additions. + + If the number of total elements in the values tensor exceeds int32 limits + then this approach falls apart. + """ + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + validate_same_dtype(self.target, [inputs[0], output]) + + values, indices = inputs + index_nodes = indices.special + + # Broadcast indices + broadcasted_tensors = tutils.broadcast_tensors( + tosa_graph, index_nodes, self.tosa_spec + ) + + values_strides = self._calculate_value_strides(values.shape) + + # The indices have already been broadcast to a common shape + # in so they are all the same. + _, index_dtype, index_shape = self._get_tensor_info(broadcasted_tensors[0]) + + N, K, W, C = self._calculate_tosa_vals(index_shape, index_nodes, values.shape) + + gather_idx_shape = [N, W] + + gather_index_name = "" + # Flatten out and shift indexes. + for i, index_node in enumerate(broadcasted_tensors): + index_name, _, _ = self._get_tensor_info(index_node) + index_name = index_node.name + + stride_shifted_indices = tosa_graph.addIntermediate( + index_shape, + index_dtype, + ) + + # Division by C is necessary when len(indices) < values.rank + # When there are dimensions left unindexed that changes the + # channels and thus the stride-shift. + data = np.full(index_shape, int(values_strides[i] / C)) + mul_const = tosa_graph.addConst(index_shape, index_dtype, data) + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [index_name, mul_const.name], + [stride_shifted_indices.name], + attr, + ) + + reshaped_idxs = tosa_graph.addIntermediate( + gather_idx_shape, + index_dtype, + ) + tutils.build_reshape( + tosa_graph, + stride_shifted_indices.name, + gather_idx_shape, + reshaped_idxs.name, + ) + + # Guarantees that the accumulation tensor is properly + # initialized and does not contain junk data. + if i == 0: + gather_index_name = reshaped_idxs.name + else: + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [gather_index_name, reshaped_idxs.name], + [gather_index_name], + ) + + gather_vals_shape = [N, K, C] + reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) + tutils.build_reshape( + tosa_graph, values.name, gather_vals_shape, reshaped_input.name + ) + + gather_out_shape = (N, W, C) + gather_out = tosa_graph.addIntermediate( + gather_out_shape, + output.dtype, + ) + tosa_graph.addOperator( + ts.TosaOp.Op().GATHER, + [reshaped_input.name, gather_index_name], + [gather_out.name], + None, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + tutils.build_reshape(tosa_graph, gather_out.name, output_shape, output.name) + + +@register_node_visitor +class IndexTensorVisitor(CommonIndexTensorVisitor): + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """ + This approach uses the fact that all indexing tensors are incremented + simultaneously and they essentially act as a map along the corresponding + dimensions of the values tensor. + Note: that this does not hold true when slicing or ellipsis ops + are involved as such they are not currently not supported. + + As such this approach flattens out the values tensor and + constructs a flattened out index obtained by flattening out the + index tensors, multiplying them by the relevant stride and accumulating them. + + This approach suffers from the fact that we are taking a number of index tensors of + type int32 and applying multiplications and additions. + + If the number of total elements in the values tensor exceeds int32 limits + then this approach falls apart. + """ + import serializer.tosa_serializer as ts + + validate_same_dtype(self.target, [inputs[0], output]) + + values, indices = inputs + index_nodes = indices.special + + # Broadcast indices + broadcasted_tensors = tutils.broadcast_tensors( + tosa_graph, index_nodes, self.tosa_spec + ) + + # Calculate strides so we can shift indices down the line. + values_strides = self._calculate_value_strides(values.shape) + + # The indices have already been broadcast to a common shape + # in so they are all the same. + _, index_dtype, index_shape = self._get_tensor_info(broadcasted_tensors[0]) + + N, K, W, C = self._calculate_tosa_vals(index_shape, index_nodes, values.shape) + + gather_idx_shape = [N, W] + + gather_index_name = "" + # Flatten out and shift indexes. + for i, index_node in enumerate(broadcasted_tensors): + index_name, _, _ = self._get_tensor_info(index_node) + index_name = index_node.name + + stride_shifted_indices = tosa_graph.addIntermediate( + index_shape, + index_dtype, + ) + + # Division by C is necessary when len(indices) < values.rank + # When there are dimensions left unindexed that changes the + # channels and thus the stride-shift. + data = np.full(index_shape, int(values_strides[i] / C)) + mul_const = tosa_graph.addConst(index_shape, index_dtype, data) + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [index_name, mul_const.name, f"{node.name}_{i}_shift"], + [stride_shifted_indices.name], + ) + + reshaped_idxs = tosa_graph.addIntermediate( + gather_idx_shape, + index_dtype, + ) + tutils.build_reshape_tosa_1_0( + tosa_graph, + stride_shifted_indices.name, + gather_idx_shape, + reshaped_idxs.name, + shape_name_override=f"{node.name}_{i}_shape", + ) + + # Guarantees that the accumulation tensor is properly + # initialized and does not contain junk data. + if i == 0: + gather_index_name = reshaped_idxs.name + else: + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [gather_index_name, reshaped_idxs.name], + [gather_index_name], + ) + + gather_vals_shape = [N, K, C] + reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) + + tutils.build_reshape_tosa_1_0( + tosa_graph, + values.name, + gather_vals_shape, + reshaped_input.name, + shape_name_override=f"{node.name}_index_shape", + ) + + gather_out_shape = (N, W, C) + gather_out = tosa_graph.addIntermediate( + gather_out_shape, + output.dtype, + ) + tosa_graph.addOperator( + ts.TosaOp.Op().GATHER, + [reshaped_input.name, gather_index_name], + [gather_out.name], + None, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + + tutils.build_reshape_tosa_1_0( + tosa_graph, + gather_out.name, + list(output_shape), + output.name, + shape_name_override=f"{node.name}_output_shape", + ) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 32001671adb..b3c779477ca 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -47,7 +47,7 @@ def define_node( ) -> None: import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, @@ -60,6 +60,10 @@ def define_node( kernel_size = inputs[1].special stride = inputs[2].special + if len(inputs) == 6: + ceil_mode = bool(inputs[5].number) + else: + ceil_mode = False try: pad_size_list = inputs[3].special pad_size_list = [ @@ -68,7 +72,7 @@ def define_node( pad_size_list[1], pad_size_list[1], ] - except IndexError: + except (IndexError, AttributeError): pad_size_list = [0, 0, 0, 0] # Adjust the padding as necessary @@ -77,12 +81,14 @@ def define_node( kernel_size[0], stride[0], pad_size_list[1], + ceil_mode, ) pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size[1], stride[1], pad_size_list[3], + ceil_mode, ) accumulator_type = output.dtype @@ -91,12 +97,12 @@ def define_node( input_zp = 0 if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) - input_zp = input_qparams[0].zp + input_zp = input_qparams[0].get_zp_per_tensor() output_zp = 0 if output.dtype == ts.DType.INT8: output_qparams = get_output_qparams(node) - output_zp = output_qparams[0].zp + output_zp = output_qparams[0].get_zp_per_tensor() attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -138,7 +144,7 @@ def define_node( import serializer.tosa_serializer as ts # type: ignore - validate_num_inputs(self.target, inputs, [3, 4]) + validate_num_inputs(self.target, inputs, [3, 4, 5, 6]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, @@ -151,6 +157,11 @@ def define_node( kernel_size = inputs[1].special stride = inputs[2].special + if len(inputs) == 6: + ceil_mode = bool(inputs[5].number) + else: + ceil_mode = False + try: pad_size_list = inputs[3].special pad_size_list = [ @@ -159,7 +170,7 @@ def define_node( pad_size_list[1], pad_size_list[1], ] - except IndexError: + except (IndexError, AttributeError): pad_size_list = [0, 0, 0, 0] # Adjust the padding as necessary @@ -168,12 +179,14 @@ def define_node( kernel_size[0], stride[0], pad_size_list[1], + ceil_mode, ) pad_size_list[3] = adjust_pooling_pad_if_needed( input_tensor.shape[3], kernel_size[1], stride[1], pad_size_list[3], + ceil_mode, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 753bdfabacd..61f01cb7099 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -70,14 +70,14 @@ def define_node( input_A_rescaled = tqutils.build_rescale_to_int32( tosa_graph, input_A, - input_A_qargs.zp, - [1.0], + input_A_qargs.get_zp_per_tensor(), + 1.0, ) input_B_rescaled = tqutils.build_rescale_to_int32( tosa_graph, input_B, - input_B_qargs.zp, - [1.0], + input_B_qargs.get_zp_per_tensor(), + 1.0, ) output_shape = tutils.tosa_shape(output.shape, output.dim_order) @@ -101,7 +101,9 @@ def define_node( [mul_output.name], attr, ) - output_scale = input_A_qargs.scale * input_B_qargs.scale + output_scale = ( + input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor() + ) tqutils.insert_rescale_op_to_int8(tosa_graph, mul_output, output_scale, node) @@ -174,15 +176,15 @@ def define_node( input_A_rescaled = tqutils.build_rescale_to_int32( tosa_graph, input_A, - input_A_qargs.zp, - [1.0], + input_A_qargs.get_zp_per_tensor(), + 1.0, tosa_spec=self.tosa_spec, ) input_B_rescaled = tqutils.build_rescale_to_int32( tosa_graph, input_B, - input_B_qargs.zp, - [1.0], + input_B_qargs.get_zp_per_tensor(), + 1.0, tosa_spec=self.tosa_spec, ) @@ -196,7 +198,9 @@ def define_node( [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], [mul_output.name], ) - output_scale = input_A_qargs.scale * input_B_qargs.scale + output_scale = ( + input_A_qargs.get_scale_per_tensor() * input_B_qargs.get_scale_per_tensor() + ) tqutils.insert_rescale_op_to_int8( tosa_graph, mul_output, output_scale, node, self.tosa_spec ) diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index 61b912669ff..e3b3eabf9ba 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -31,8 +31,8 @@ def get_negate_zero_points(node: torch.fx.Node, is_int8: bool) -> tuple[int, int """ if is_int8: return ( - get_input_qparams(node)[0].zp, - get_output_qparams(node)[0].zp, + get_input_qparams(node)[0].get_zp_per_tensor(), + get_output_qparams(node)[0].get_zp_per_tensor(), ) return (0, 0) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py index 1a5f91a81e6..df8d3c7dbef 100644 --- a/backends/arm/operators/op_rescale.py +++ b/backends/arm/operators/op_rescale.py @@ -119,8 +119,8 @@ def define_node( input_node=inputs[0], output_name=output.name, output_type=output.dtype, - input_zp=input_zp, - output_zp=output_zp, + input_zp=[input_zp], + output_zp=[output_zp], rounding_mode=RoundingMode.SINGLE_ROUND, per_channel=False, ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index f888bd8b72c..84a662db01c 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -106,8 +106,6 @@ def define_node( if inputs[0].dtype == ts.DType.INT8: return super().define_node(node, tosa_graph, inputs, output) - validate_num_inputs(self.target, inputs, 3) - tensor = inputs[0] input_shape = list(tensor.shape) dim = int(inputs[1].number % len(input_shape)) diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py index 7e0d4a49556..c7edee9d882 100644 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -110,8 +110,8 @@ def in_int16_range(x): input_node=intermediate, output_name=output.name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=0, + input_zp=[0], + output_zp=[0], is_double_round=False, ) else: @@ -232,8 +232,8 @@ def in_int16_range(x): input_node=intermediate, output_name=output.name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=0, + input_zp=[0], + output_zp=[0], rounding_mode=RoundingMode.SINGLE_ROUND, ) else: diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 2dea9e2874b..fde76f31c7a 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from math import ceil, floor from typing import Any, List, Optional from executorch.backends.arm.operators.node_visitor import NodeVisitor @@ -183,11 +184,18 @@ def validate_valid_dtype( def adjust_pooling_pad_if_needed( - input_size: int, kernel_size: int, stride: int, pad: int + input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool ) -> int: """ - Calculates the padding that needs to be removed to a pooling window to make it - divisible by the kernels stride. All inputs should correspond to the same dimension. + The Aten pooling ops has one value 'pad' per dimension to specify padding, but they + do not require input and output sizes to match up perfectly. Instead, the output + size is rounded up or down depending on ceil_mode, and padding at the end of the + input is automatically added or removed. TOSA on the other hand specifies two + padding values, one for pre-padding and one for post-padding, and these must satisfy + + output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1 + + This function returns the post_pad value required to satisfy the above condition. Parameters: ----------- @@ -205,15 +213,16 @@ def adjust_pooling_pad_if_needed( Output: ------- - An int, representing the padding to remove to make the window divisible. + An int, giving the post-padding to use for the """ - if pad == 0: - return pad - mod_remainder = (input_size + 2 * pad - kernel_size) % stride + if ceil_mode: + output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1 + else: + output_size = floor((input_size - kernel_size + 2 * pad) / stride) + 1 - # No need to adjust - if mod_remainder == 0: - return pad + # Solve for post_pad from + # output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1 + adjusted_post_pad = (output_size - 1) * stride - input_size + kernel_size - pad - return pad - mod_remainder + return adjusted_post_pad diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 81a2946c8fb..9c0c15364fc 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -17,6 +17,7 @@ from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa_mapping import TosaArg @@ -40,6 +41,30 @@ def define_node( validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) + if self.target in [ + "aten.bitwise_and.Tensor", + "aten.bitwise_xor.Tensor", + "aten.bitwise_or.Tensor", + "aten.bitwise_left_shift.Tensor", + ]: + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + output.tosa_spec, + ) + if self.target in [ + "aten.logical_and.default", + "aten.logical_xor.defaul", + "aten.logical_or.default", + ]: + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.BOOL], + output.tosa_spec, + ) + tosa_graph.addOperator( tosa_op, [inputs[0].name, inputs[1].name], [output.name] ) @@ -66,6 +91,30 @@ def define_node( validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs, output], ts) + if self.target in [ + "aten.bitwise_and.Tensor", + "aten.bitwise_xor.Tensor", + "aten.bitwise_or.Tensor", + "aten.bitwise_left_shift.Tensor", + ]: + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + output.tosa_spec, + ) + if self.target in [ + "aten.logical_and.default", + "aten.logical_xor.defaul", + "aten.logical_or.default", + ]: + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.BOOL], + output.tosa_spec, + ) + tosa_graph.addOperator( tosa_op, [inputs[0].name, inputs[1].name], [output.name] ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index c20065654ca..94e2ae74a7a 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -38,7 +38,6 @@ HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, ObserverOrFakeQuantizeConstructor, PerChannelMinMaxObserver, PlaceholderObserver, @@ -95,24 +94,26 @@ def get_symmetric_quantization_config( **extra_args, ), ) + + # Setup quantization config for weights weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) + # Determine the right observer/fake-quant constructor if is_qat: - # TODO: qat + per channel? - weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize - elif is_per_channel: - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + # Set plain fake-quant with true min/max + weight_observer_or_fake_quant_ctr = FakeQuantize + else: + # PTQ: set min/max observer + weight_observer_or_fake_quant_ctr = ( + PerChannelMinMaxObserver if is_per_channel else MinMaxObserver + ) + + extra_args = {"eps": 2**-12} - extra_args: Dict[str, Any] = {"eps": 2**-12} - if is_qat: - if weight_qscheme == torch.per_tensor_symmetric: - extra_args["observer"] = MovingAverageMinMaxObserver - else: - extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] weight_quantization_spec = QuantizationSpec( dtype=torch.int8, quant_min=weight_qmin, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5c2f7822097..83a648c7d8a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -260,6 +260,7 @@ def _match_pattern( torch.ops.aten.clamp.Tensor, torch.ops.aten.unflatten.int, torch.ops.aten.index_select.default, + torch.ops.aten.index.Tensor, ] _one_to_one_shared_input_or_input_act_qspec = [ diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 747d945d30c..8f31f019332 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -78,11 +78,11 @@ def _derive_qparams_fn( ) act_obs_or_fq = obs_or_fqs[0] weight_obs_or_fq = obs_or_fqs[1] - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - return torch.tensor([act_scale * weight_scale]).to( + act_scale, _ = act_obs_or_fq.calculate_qparams() + weight_scale, _ = weight_obs_or_fq.calculate_qparams() + return torch.tensor(act_scale * weight_scale).to( torch.float32 - ), torch.tensor([0]).to(torch.int32) + ), torch.full_like(weight_scale, fill_value=0, dtype=torch.int32) if node.target in [ torch.ops.aten.conv1d.default, @@ -92,13 +92,25 @@ def _derive_qparams_fn( ]: input_act = node.args[0] weight = node.args[1] + # If the weights are quantized per_tensor, do the same with bias + qscheme = ( + torch.per_tensor_symmetric + if self.weight is None + else self.weight.qscheme + ) + ch_axis = None + if self.weight is not None: + if qscheme == torch.per_channel_symmetric: + ch_axis = self.weight.ch_axis + quantization_spec = DerivedQuantizationSpec( derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] derive_qparams_fn=_derive_qparams_fn, dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=torch.per_tensor_symmetric, + qscheme=qscheme, + ch_axis=ch_axis, ) return quantization_spec # type: ignore[return-value] diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index f5e4af860d6..d29c32b02f3 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -113,7 +113,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { BackendInitContext& context, FreeableBuffer* processed, ArrayRef compile_specs) const override { - ET_LOG(Info, "EthosUBackend::init %p", processed->data()); + ET_LOG(Info, "data:%p", processed->data()); const char* data = static_cast(processed->data()); size_t size = processed->size(); @@ -173,7 +173,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { static_cast(execution_handle->processed->data()); EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope); - ET_LOG(Debug, "EthosUBackend::execute %p", data); + ET_LOG(Debug, "data:%p", data); EXECUTORCH_PROF_START( event_tracer, @@ -182,7 +182,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // Read key sections from the vela_bin_stream if (vela_bin_read(data, &handles, execution_handle->processed->size()) == false) { - ET_LOG(Error, "EthosUBackend::vela_read: error, invalid binary layout"); + ET_LOG(Error, "vela_read: error, invalid binary layout"); return Error::InvalidProgram; } EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope); @@ -193,9 +193,16 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // the end of the execution of the Ethos-U custom delegate char* ethosu_scratch = static_cast(temp_allocator->allocate(handles.scratch_data_size)); + if (ethosu_scratch == nullptr) { + ET_LOG( + Error, + "Failed to allocate scratch buffer of %zu bytes from temp_allocator", + handles.scratch_data_size); + return Error::MemoryAllocationFailed; + } ET_LOG( Debug, - "EthosUBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n fast scratch %p %zu\n", + "Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n fast scratch %p %zu\n", handles.cmd_data, handles.cmd_data_size, handles.weight_data, @@ -227,12 +234,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { supported |= (tensor_in.scalar_type() == ScalarType::Short and handles.inputs->io[i].elem_size == 2); + // bool (IOQDQ pass prepared networks) + supported |= + (tensor_in.scalar_type() == ScalarType::Bool and + handles.inputs->io[i].elem_size == 1); if (!supported) { ET_LOG( Error, - "Input %d expected Integer (4 byte) or Char (1 byte) integer inputs, got ScalarType id %s", + "Input %d expected Integer (4 byte), Char (1 byte) or Bool (1 byte) integer inputs, got ScalarType id %s size %d", i, - executorch::runtime::toString(tensor_in.scalar_type())); + executorch::runtime::toString(tensor_in.scalar_type()), + handles.inputs->io[i].elem_size); return Error::InvalidProgram; } supported = executorch::runtime::is_contiguous_dim_order( @@ -250,15 +262,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { bool permuted_input_shape; ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute( i, tensor_in, &handles.inputs->io[i], &permuted_input_shape)); - bool both_char = tensor_in.scalar_type() == ScalarType::Char and - handles.inputs->io[i].elem_size == 1; - bool both_int = tensor_in.scalar_type() == ScalarType::Int and + bool both_int = tensor_in.scalar_type() == ScalarType::Int && handles.inputs->io[i].elem_size == 4; - bool both_short = tensor_in.scalar_type() == ScalarType::Short and + bool both_char = tensor_in.scalar_type() == ScalarType::Char && + handles.inputs->io[i].elem_size == 1; + bool both_short = tensor_in.scalar_type() == ScalarType::Short && handles.inputs->io[i].elem_size == 2; + bool both_bool = tensor_in.scalar_type() == ScalarType::Bool && + (handles.inputs->io[i].elem_size == 1); // Select a compatible copy routine - if (both_char && permuted_input_shape) { + if ((both_char || both_bool) && permuted_input_shape) { EXECUTORCH_PROF_SCOPE( event_tracer, "+EthosUBackend::execute()handles.input.permute_CHW_to_HWC()"); @@ -269,7 +283,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { tensor_in.size(1), tensor_in.size(2), tensor_in.size(3)); - } else if (both_char || both_int || both_short) { + } else if (both_char || both_int || both_short || both_bool) { EXECUTORCH_PROF_SCOPE( event_tracer, "+EthosUBackend::execute()handles.input.memcpy()"); // Sizes match and elt size matches so memcpy @@ -301,7 +315,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { std::unique_ptr( ethosu_reserve_driver(), ethosu_release_driver); if (driver == NULL) { - ET_LOG(Error, "EthosUBackend::execute: ethosu_reserve_driver failed"); + ET_LOG(Error, "ethosu_reserve_driver failed"); return Error::InvalidState; } @@ -333,10 +347,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope); if (result != 0) { - ET_LOG( - Error, - "EthosUBackend::execute: Ethos-U invocation failed error (%d)", - result); + ET_LOG(Error, "Ethos-U invocation failed error (%d)", result); return Error::InvalidProgram; } int tensor_dim = 0, io_dim = 0; @@ -359,7 +370,9 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { bool permuted_output_shape; ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute( i, tensor_out, &handles.outputs->io[i], &permuted_output_shape)); - if (tensor_out.scalar_type() == ScalarType::Char && + + if ((tensor_out.scalar_type() == ScalarType::Char || + tensor_out.scalar_type() == ScalarType::Bool) && permuted_output_shape) { EXECUTORCH_PROF_SCOPE( event_tracer, @@ -375,17 +388,12 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { tensor_out.size(3)); } else { EXECUTORCH_PROF_SCOPE( - event_tracer, "+EthosUBackend::execute()handles.output.move()"); - for (int j = 0; j < tensor_out.numel(); j++) { - if (tensor_out.scalar_type() == ScalarType::Char) { - const char* output_address = static_cast(output_addr); - tensor_out.mutable_data_ptr()[j] = output_address[j]; - } else { - const int* output_address = - reinterpret_cast(output_addr); - tensor_out.mutable_data_ptr()[j] = output_address[j]; - } - } + event_tracer, "+EthosUBackend::execute()handles.output.memcpy()"); + + memcpy( + tensor_out.mutable_data_ptr(), + static_cast(output_addr), + tensor_out.nbytes()); } } if (tensor_dim != io_dim) { diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 189aa7f8e59..f23191b55b0 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -13,6 +13,7 @@ "hardswish.default", "linear.default", "maximum.default", + "multihead_attention.default", "adaptive_avg_pool2d.default", "bitwise_right_shift.Tensor", "bitwise_left_shift.Tensor", diff --git a/backends/arm/test/misc/test_lifted_tensor.py b/backends/arm/test/misc/test_lifted_tensor.py index 7f1a9938037..c17d93765e5 100644 --- a/backends/arm/test/misc/test_lifted_tensor.py +++ b/backends/arm/test/misc/test_lifted_tensor.py @@ -12,7 +12,7 @@ TosaPipelineBI, TosaPipelineMI, ) -from executorch.backends.xnnpack.test.tester import ToEdge +from executorch.backends.test.harness.stages import StageType input_t1 = Tuple[torch.Tensor] @@ -72,9 +72,8 @@ def test_partition_lifted_tensor_tosa_MI(test_data: input_t1): use_to_edge_transform_and_lower=False, ) pipeline.run() - to_edge_stage_name = pipeline.tester.stage_name(ToEdge) signature = ( - pipeline.tester.stages[to_edge_stage_name] + pipeline.tester.stages[StageType.TO_EDGE] .artifact.exported_program() .graph_signature ) @@ -94,9 +93,8 @@ def test_partition_lifted_tensor_tosa_BI(test_data: input_t1): use_to_edge_transform_and_lower=False, ) pipeline.run() - to_edge_stage_name = pipeline.tester.stage_name(ToEdge) signature = ( - pipeline.tester.stages[to_edge_stage_name] + pipeline.tester.stages[StageType.TO_EDGE] .artifact.exported_program() .graph_signature ) diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py index c83a70e001e..ac513530e04 100644 --- a/backends/arm/test/models/test_mobilenet_v2_arm.py +++ b/backends/arm/test/models/test_mobilenet_v2_arm.py @@ -46,6 +46,7 @@ def test_mv2_tosa_BI(): aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, + per_channel_quantization=True, atol=0.25, qtol=1, ) @@ -62,6 +63,7 @@ def test_mv2_u55_BI(): exir_ops=[], run_on_fvp=True, use_to_edge_transform_and_lower=True, + per_channel_quantization=True, atol=0.25, qtol=1, ) @@ -78,6 +80,7 @@ def test_mv2_u85_BI(): exir_ops=[], run_on_fvp=True, use_to_edge_transform_and_lower=True, + per_channel_quantization=True, atol=0.25, qtol=1, ) diff --git a/backends/arm/test/models/test_torch_functions.py b/backends/arm/test/models/test_torch_functions.py index 5cd4bd3aaed..c7fc1654caa 100644 --- a/backends/arm/test/models/test_torch_functions.py +++ b/backends/arm/test/models/test_torch_functions.py @@ -130,7 +130,6 @@ def test_torch_fns_MI(test_data): "topk": "NotImplementedError: No registered serialization name for found", "sort": "NotImplementedError: No registered serialization name for found", "t": "MLETORCH-855: Issue with Quantization folding.", - "norm": "An error occurred when running the 'KeepDimsFalseToSqueezePass' pass after the following passes:", }, strict=False, ) diff --git a/backends/arm/test/ops/test_any.py b/backends/arm/test/ops/test_any.py index 6ddef1ad0b5..338c5f05cc6 100644 --- a/backends/arm/test/ops/test_any.py +++ b/backends/arm/test/ops/test_any.py @@ -6,7 +6,6 @@ from typing import List, Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -125,14 +124,30 @@ def forward(self, x: torch.Tensor): @common.parametrize("test_data", test_data) def test_any_tosa_MI(test_data: input_t1): op, test_input = test_data() - pipeline = TosaPipelineMI[input_t1](op, test_input(), op.aten_op, op.exir_op) + pipeline = TosaPipelineMI[input_t1]( + op, + test_input(), + op.aten_op, + op.exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", test_data) def test_any_tosa_BI(test_data: input_t1): op, test_input = test_data() - pipeline = TosaPipelineBI[input_t1](op, test_input(), op.aten_op, op.exir_op) + pipeline = TosaPipelineBI[input_t1]( + op, + test_input(), + op.aten_op, + op.exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -153,7 +168,6 @@ def test_any_u55_BI(test_data: input_t1): @common.parametrize("test_data", test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_any_u85_BI(test_data: input_t1): op, test_input = test_data() @@ -163,6 +177,9 @@ def test_any_u85_BI(test_data: input_t1): op.aten_op, op.exir_op, run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 66d56ce584c..e2bbfc3a8cd 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -10,7 +10,7 @@ import torch -from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineBI, @@ -26,27 +26,19 @@ input_t = Tuple[torch.Tensor] -class AvgPool2d(torch.nn.Module): - def __init__( - self, - kernel_size: int | Tuple[int, int], - stride: int | Tuple[int, int], - padding: int | Tuple[int, int], - ): - super().__init__() - self.avg_pool_2d = torch.nn.AvgPool2d( - kernel_size=kernel_size, stride=stride, padding=padding - ) - - def forward(self, x): - return self.avg_pool_2d(x) +class AvgPool2d(torch.nn.modules.AvgPool2d): + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) test_modules = { - "zeros": lambda: (AvgPool2d(4, 2, 0), (torch.zeros(1, 16, 50, 32),)), - "ones": lambda: (AvgPool2d(4, 2, 0), (torch.ones(1, 16, 50, 32),)), - "rand": lambda: (AvgPool2d(4, 2, 0), (torch.rand(1, 16, 50, 32),)), - "randn": lambda: (AvgPool2d(4, 2, 0), (torch.randn(1, 16, 50, 32),)), + "zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)), + "ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)), + "rand": lambda: (AvgPool2d(4, 2, 0, False, True, 16), (torch.rand(1, 16, 50, 32),)), + "randn": lambda: ( + AvgPool2d(4, 2, 0, divisor_override=16), + (torch.randn(1, 16, 50, 32),), + ), "kernel_3x3_stride_1_pad_1": lambda: ( AvgPool2d((3, 3), (1, 1), 1), (torch.rand(1, 16, 50, 32),), @@ -60,7 +52,7 @@ def forward(self, x): (torch.rand(1, 16, 50, 32),), ), "non_divisible_window": lambda: ( - AvgPool2d(3, 2, 1), + AvgPool2d(3, 2, 1, count_include_pad=False), (torch.rand(1, 16, 112, 112),), ), "non_divisible_window_height": lambda: ( @@ -68,9 +60,37 @@ def forward(self, x): (torch.rand(1, 16, 56, 56),), ), "non_divisible_window_width": lambda: ( - AvgPool2d(3, (1, 2), 1), + AvgPool2d(3, (1, 2), 1, count_include_pad=False), (torch.rand(1, 16, 56, 56),), ), + "non_divisible_window_ceil_mode": lambda: ( + AvgPool2d(3, 2, 1, True), + (torch.rand(1, 16, 112, 112),), + ), + "non_divisible_window_height_ceil_mode": lambda: ( + AvgPool2d(3, (2, 1), 1, True, False), + (torch.rand(1, 1, 14, 14),), + ), + "non_divisible_window_width_ceil_mode": lambda: ( + AvgPool2d(3, (1, 2), 1, True, True), + (torch.rand(1, 1, 14, 14),), + ), + "divisor_override": lambda: ( + AvgPool2d(3, 2, 1, False, False, divisor_override=2), + (torch.rand(1, 1, 14, 14),), + ), + "divisor_override_count_include_pad": lambda: ( + AvgPool2d(3, 2, 1, False, True, divisor_override=2), + (torch.rand(1, 1, 14, 14),), + ), + "divisor_override_ceil_mode": lambda: ( + AvgPool2d(3, 2, 1, True, False, divisor_override=2), + (torch.rand(1, 1, 14, 14),), + ), + "divisor_override_ceil_mode_count_include_pad": lambda: ( + AvgPool2d(3, 2, 1, True, True, divisor_override=2), + (torch.rand(1, 1, 14, 14),), + ), } @@ -83,11 +103,8 @@ def test_avg_pool2d_tosa_MI(test_module): input_tensor, aten_op, exir_op, - run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), ) - if conftest.is_option_enabled("tosa_ref_model"): - pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) - pipeline.run() + pipeline.run() @common.parametrize("test_module", test_modules) @@ -99,11 +116,8 @@ def test_avg_pool2d_tosa_BI(test_module): input_tensor, aten_op, exir_op, - run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), ) - if conftest.is_option_enabled("tosa_ref_model"): - pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) - pipeline.run() + pipeline.run() @common.parametrize("test_module", test_modules) @@ -118,7 +132,6 @@ def test_avg_pool2d_u55_BI(test_module): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) pipeline.run() @@ -134,27 +147,25 @@ def test_avg_pool2d_u85_BI(test_module): exir_op, run_on_fvp=True, ) - pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) - pipeline.run() reject_modules = { "kernel_1x1_stride_1_pad_0": lambda: (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), "kernel_2x9_stride_1_pad_1": lambda: ( - AvgPool2d((2, 9), 1, 1), + AvgPool2d((2, 9), 1, 1, count_include_pad=False), torch.rand(1, 16, 5, 32), ), "kernel_1x4_stride_0_pad_0": lambda: ( - AvgPool2d(1, 4, 0), + AvgPool2d(1, 4, 0, count_include_pad=False), torch.rand(1, 10, 10, 10), ), "kernel_1x257_stride_1_pad_0_large": lambda: ( - AvgPool2d((1, 257), 1, 0), + AvgPool2d((1, 257), 1, 0, count_include_pad=False), torch.rand(1, 16, 5, 300), ), "kernel_800x90_stride_1_pad_0_extreme": lambda: ( - AvgPool2d((800, 90), 1, 0), + AvgPool2d((800, 90), 1, 0, count_include_pad=False), torch.rand(1, 16, 850, 100), ), } diff --git a/backends/arm/test/ops/test_bitwise.py b/backends/arm/test/ops/test_bitwise.py index 8be8ba35b4e..032639b8607 100644 --- a/backends/arm/test/ops/test_bitwise.py +++ b/backends/arm/test/ops/test_bitwise.py @@ -6,7 +6,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -30,6 +29,22 @@ class BitwiseBinary(torch.nn.Module): torch.ones(10, 10, 10, dtype=torch.int8), torch.ones(10, 10, 10, dtype=torch.int8), ), + "pattern_int8": lambda: ( + 0xAA * torch.ones(1, 2, 2, 2, dtype=torch.int8), + 0xCC * torch.ones(1, 2, 2, 2, dtype=torch.int8), + ), + "pattern_int16": lambda: ( + 0xAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int16), + 0xCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int16), + ), + "pattern_int32": lambda: ( + 0xAAAAAAAA * torch.ones(1, 2, 2, 2, dtype=torch.int32), + 0xCCCCCCCC * torch.ones(1, 2, 2, 2, dtype=torch.int32), + ), + "pattern_bool": lambda: ( + torch.tensor([True, False, True], dtype=torch.bool), + torch.tensor([True, True, False], dtype=torch.bool), + ), "rand_rank2": lambda: ( torch.randint(-128, 127, (10, 10), dtype=torch.int8), torch.randint(-128, 127, (10, 10), dtype=torch.int8), @@ -68,7 +83,13 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor): @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -76,7 +97,13 @@ def test_bitwise_and_tensor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", And().test_data) def test_bitwise_and_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -97,11 +124,17 @@ def test_bitwise_and_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", And().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_and_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op, run_on_fvp=True + And(), + test_data(), + And().aten_op, + And().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -111,7 +144,13 @@ def test_bitwise_and_tensor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -119,7 +158,13 @@ def test_bitwise_xor_tensor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_bitwise_xor_tensor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -140,11 +185,17 @@ def test_bitwise_xor_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op, run_on_fvp=True + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -153,13 +204,29 @@ def test_bitwise_xor_tensor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_MI(test_data: input_t2): - pipeline = TosaPipelineMI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineMI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", Or().test_data) def test_bitwise_or_tensor_tosa_BI(test_data: input_t2): - pipeline = TosaPipelineBI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineBI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -179,7 +246,6 @@ def test_bitwise_or_tensor_u55_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_bitwise_or_tensor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( @@ -188,6 +254,9 @@ def test_bitwise_or_tensor_u85_BI(test_data: input_t2): Or().aten_op, Or().exir_op, run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 158c296e4ec..658978d0de8 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -327,6 +327,34 @@ def forward(self, x): batches=1, ) +conv2d_groups = Conv2d( + in_channels=12, + out_channels=9, + kernel_size=(3, 3), + stride=1, + padding=0, + dilation=1, + width=7, + height=7, + batches=1, + groups=3, + bias=False, +) + +conv2d_groups_bias = Conv2d( + in_channels=15, + out_channels=5, + kernel_size=(3, 3), + stride=1, + padding=0, + dilation=1, + width=7, + height=7, + batches=1, + groups=5, + bias=True, +) + # Shenanigan to get a nicer output when test fails. With unittest it looks like: # FAIL: test_convolution_2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1 test_modules = { @@ -348,6 +376,8 @@ def forward(self, x): "3x3_1x3x224x224_st2_pd1": lambda: conv2d_3x3_1x3x224x224_st2_pd1, "two_conv2d_nobias": lambda: two_conv2d_nobias, "two_conv2d": lambda: two_conv2d, + "groups": lambda: conv2d_groups, + "groups_bias": lambda: conv2d_groups_bias, } fvp_xfails = { diff --git a/backends/arm/test/ops/test_index_tensor.py b/backends/arm/test/ops/test_index_tensor.py new file mode 100644 index 00000000000..f1f6f5171d8 --- /dev/null +++ b/backends/arm/test/ops/test_index_tensor.py @@ -0,0 +1,462 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from enum import IntEnum +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) + + +class IndexTensorTestCommon: + """Class containing constants common between the tests""" + + aten_op = "torch.ops.aten.index.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_index_Tensor" + + # Gathers and reshapes should result in no inaccuracies + rtol = 0.0 + atol = 0.0 + + class OpPlacement(IntEnum): + """ + Simple enum used to indicate where slices or ellipsis should be placed + in tests. + IntEnum so that Dynamo does not complain about unsupported types. + """ + + BEFORE = 1 + MIDDLE = 2 + AFTER = 3 + + +input_params_slice = Tuple[ + torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor] +] +input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]] + + +class IndexTensor_Ellipsis(torch.nn.Module): + """ + There are technical limitations with torch/export as it does not support + the ellipsis class and as such the forward function has been crafted + to circumvent that limitation. + """ + + # xfail - ellipsis unsupported + test_data_ellipsis: dict[input_params] = { + "test_4d_ellipsis_before": ( + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32),), + ), + "test_4d_ellipsis_middle": ( + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.MIDDLE, + ( + torch.arange(2, dtype=torch.int32), + torch.arange(2, dtype=torch.int32), + ), + ), + "test_4d_ellipsis_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.AFTER, + (torch.arange(2, dtype=torch.int32),), + ), + } + + def forward( + self, + input_: torch.Tensor, + position: IndexTensorTestCommon.OpPlacement, + indices: Tuple[None | torch.Tensor], + ): + match position: + case IndexTensorTestCommon.OpPlacement.BEFORE: + return input_[..., indices[0]] + case IndexTensorTestCommon.OpPlacement.MIDDLE: + return input_[indices[0], ..., indices[1]] + case IndexTensorTestCommon.OpPlacement.AFTER: + return input_[indices[0], ...] + + return input_[indices] + + +@common.parametrize( + "test_data", + IndexTensor_Ellipsis.test_data_ellipsis, + xfails={ + # More info in index_tensor_support.py + "test_4d_ellipsis_before": "Ellipsis before index unsupported", + "test_4d_ellipsis_middle": "Ellipsis before index unsupported", + }, +) +def test_index_tensor_tosa_MI_ellipsis(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor_Ellipsis(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor_Ellipsis.test_data_ellipsis, + xfails={ + # More info in index_tensor_support.py + "test_4d_ellipsis_before": "Ellipsis before index unsupported", + "test_4d_ellipsis_middle": "Ellipsis before index unsupported", + }, +) +def test_index_tensor_tosa_BI_ellipsis(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor_Ellipsis(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +class IndexTensor_Slice(torch.nn.Module): + """ + There are technical limitations with Dynamo as it does not support the + slice class and as such the forward function has been crafted + to circumvent that limitation. + """ + + # xfail - None unsupported + test_data: dict[input_params_slice] = { + "test_4d_slice_before_1d_idx": ( + # Value tens is 3D because with the + torch.rand(size=(5, 3, 4, 5)), + 0, + 2, + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32),), + ), + "test_3d_slice_before_2d_idx": ( + # TODO: MLETORCH-859 - Testing framework does not support output rank > 4 + # With the bellow configuration a 4D value tensor and 2D index tensor + # results in a 5D output. + torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4), + 0, + 2, + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),), + ), + "test_4d_slice_middle": ( + torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2), + 0, + 2, + IndexTensorTestCommon.OpPlacement.MIDDLE, + ( + torch.arange(2, dtype=torch.int32), + torch.arange(2, dtype=torch.int32), + ), + ), + "test_4d_slice_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. + torch.rand(size=(25, 5, 13, 7)), + 0, + 2, + IndexTensorTestCommon.OpPlacement.AFTER, + (torch.arange(2, dtype=torch.int32),), + ), + } + + def forward( + self, + input_: torch.Tensor, + slice_start: int, + slice_end: int, + position: IndexTensorTestCommon.OpPlacement, + indices: Tuple[None | torch.Tensor], + ): + match position: + case IndexTensorTestCommon.OpPlacement.BEFORE: + return input_[slice_start:slice_end, indices[0]] + case IndexTensorTestCommon.OpPlacement.MIDDLE: + return input_[indices[0], slice_start:slice_end, indices[1]] + case IndexTensorTestCommon.OpPlacement.AFTER: + return input_[indices[0], slice_start:slice_end] + + +@common.parametrize( + "test_data", + IndexTensor_Slice.test_data, + xfails={ + # More info in index_tensor_support.py + "test_4d_slice_before_1d_idx": "Slice before index unsupported", + "test_3d_slice_before_2d_idx": "Slice before index unsupported", + "test_4d_slice_middle": "Slice before index unsupported", + }, +) +def test_index_tensor_tosa_MI_slice(test_data: input_params_slice): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params_slice]( + IndexTensor_Slice(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor_Slice.test_data, + xfails={ + # More info in index_tensor_support.py + "test_4d_slice_before_1d_idx": "Slice before index unsupported", + "test_3d_slice_before_2d_idx": "Slice before index unsupported", + "test_4d_slice_middle": "Slice before index unsupported", + }, +) +def test_index_tensor_tosa_BI_slice(test_data: input_params_slice): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params_slice]( + IndexTensor_Slice(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +class IndexTensor(torch.nn.Module): + test_data: dict[input_params] = { + "test_2d_1_idx": (torch.rand(5, 2), (torch.arange(5, dtype=torch.int32),)), + "test_2d_1_less_than_max_idx": ( + torch.rand(5, 2), + (torch.arange(3, dtype=torch.int32),), + ), + "test_2d_1_2d_idx": ( + torch.rand(5, 2), + (torch.randint(5, size=(4, 3), dtype=torch.int32)), + ), + "test_2d_2_idx": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5,), dtype=torch.int32), + torch.randint(2, size=(5,), dtype=torch.int32), + ), + ), + "test_2d_2_2d_idx_broadcastable": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5, 3), dtype=torch.int32), + torch.randint(2, size=(1, 3), dtype=torch.int32), + ), + ), + "test_2d_2_2d_idx_broadcastable_2": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5, 1), dtype=torch.int32), + torch.randint(2, size=(3,), dtype=torch.int32), + ), + ), + "test_3d_1_idx": (torch.rand(12, 3, 7), (torch.arange(12, dtype=torch.int32),)), + "test_3d_2_idx": ( + torch.rand(12, 3, 7), + ( + torch.arange(12, dtype=torch.int32), + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx": ( + torch.rand(12, 3, 7), + ( + torch.arange(12, dtype=torch.int32), + torch.randint(3, size=(12,), dtype=torch.int32), + torch.randint(7, size=(12,), dtype=torch.int32), + ), + ), + "test_4d_1_idx": ( + torch.rand(15, 3, 7, 2), + (torch.arange(15, dtype=torch.int32),), + ), + "test_4d_2_idx": ( + torch.rand(15, 3, 7, 2), + ( + torch.randint(15, size=(15,), dtype=torch.int32), + torch.randint(3, size=(1,), dtype=torch.int32), + ), + ), + "test_4d_3_idx": ( + torch.rand(15, 3, 7, 2), + ( + torch.arange(15, dtype=torch.int32), + torch.randint(3, size=(15,), dtype=torch.int32), + torch.randint(7, size=(15,), dtype=torch.int32), + ), + ), + "test_4d_4_id_broadcastable": ( + torch.rand(15, 3, 7, 2), + ( + torch.arange(15, dtype=torch.int32), + torch.randint(3, size=(3, 1), dtype=torch.int32), + torch.randint(6, size=(6, 1, 1), dtype=torch.int32), + torch.randint(2, size=(15,), dtype=torch.int32), + ), + ), + } + + # xfail - None (unsqueeze) unsupported + test_data_none: dict[input_params] = { + "test_3d_3_idx_with_none_before": ( + torch.rand(12, 3, 7), + ( + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx_with_2_none_before": ( + torch.rand(12, 3, 7), + ( + None, + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx_with_none_around": ( + torch.rand(12, 3, 7), + ( + None, + torch.randint(3, size=(12,), dtype=torch.int32), + None, + ), + ), + "test_3d_3_idx_with_none_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. + torch.rand(12, 3, 7), + ( + torch.randint(3, size=(12,), dtype=torch.int32), + None, + ), + ), + "test_3d_3_idx_with_none_middle": ( + torch.rand(12, 3, 7), + ( + torch.randint(3, size=(12,), dtype=torch.int32), + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + } + + def forward(self, input_: torch.Tensor, indices: Tuple[None | torch.Tensor]): + return input_[indices] + + +@common.parametrize("test_data", IndexTensor.test_data) +def test_index_tensor_tosa_MI(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize("test_data", IndexTensor.test_data) +def test_index_tensor_tosa_BI(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor.test_data_none, + xfails={ + # More info in index_tensor_support.py + "test_3d_3_idx_with_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_2_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_around": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_middle": "None (Unsqueeze) unsupported", + }, +) +def test_index_tensor_tosa_MI_none(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor.test_data_none, + xfails={ + # More info in index_tensor_support.py + "test_3d_3_idx_with_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_2_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_around": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_middle": "None (Unsqueeze) unsupported", + }, +) +def test_index_tensor_tosa_BI_none(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) diff --git a/backends/arm/test/ops/test_logical.py b/backends/arm/test/ops/test_logical.py index 139653eea97..1a056e31b3c 100644 --- a/backends/arm/test/ops/test_logical.py +++ b/backends/arm/test/ops/test_logical.py @@ -6,7 +6,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -84,7 +83,13 @@ def forward(self, tensor: torch.Tensor): @common.parametrize("test_data", And().test_data) def test_logical_and_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -92,7 +97,13 @@ def test_logical_and_tosa_MI(test_data: input_t2): @common.parametrize("test_data", And().test_data) def test_logical_and_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op + And(), + test_data(), + And().aten_op, + And().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -113,11 +124,17 @@ def test_logical_and_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", And().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_and_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - And(), test_data(), And().aten_op, And().exir_op, run_on_fvp=True + And(), + test_data(), + And().aten_op, + And().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -127,7 +144,13 @@ def test_logical_and_u85_BI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_logical_xor_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -135,7 +158,13 @@ def test_logical_xor_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) def test_logical_xor_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -156,11 +185,17 @@ def test_logical_xor_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Xor().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_xor_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Xor(), test_data(), Xor().aten_op, Xor().exir_op, run_on_fvp=True + Xor(), + test_data(), + Xor().aten_op, + Xor().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -169,13 +204,29 @@ def test_logical_xor_u85_BI(test_data: input_t2): @common.parametrize("test_data", Or().test_data) def test_logical_or_tosa_MI(test_data: input_t2): - pipeline = TosaPipelineMI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineMI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.run() @common.parametrize("test_data", Or().test_data) def test_logical_or_tosa_BI(test_data: input_t2): - pipeline = TosaPipelineBI[input_t2](Or(), test_data(), Or().aten_op, Or().exir_op) + pipeline = TosaPipelineBI[input_t2]( + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + atol=0, + rtol=0, + qtol=0, + ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") pipeline.run() @@ -195,11 +246,17 @@ def test_logical_or_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Or().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_or_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Or(), test_data(), Or().aten_op, Or().exir_op, run_on_fvp=True + Or(), + test_data(), + Or().aten_op, + Or().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -209,7 +266,13 @@ def test_logical_or_u85_BI(test_data: input_t2): @common.parametrize("test_data", Not().test_data) def test_logical_not_tosa_MI(test_data: input_t2): pipeline = TosaPipelineMI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.run() @@ -217,7 +280,13 @@ def test_logical_not_tosa_MI(test_data: input_t2): @common.parametrize("test_data", Not().test_data) def test_logical_not_tosa_BI(test_data: input_t2): pipeline = TosaPipelineBI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") @@ -238,11 +307,17 @@ def test_logical_not_u55_BI_not_delegated(test_data: input_t2): @common.parametrize("test_data", Not().test_data) -@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.") @common.XfailIfNoCorstone320 def test_logical_not_u85_BI(test_data: input_t2): pipeline = EthosU85PipelineBI[input_t2]( - Not(), test_data(), Not().aten_op, Not().exir_op, run_on_fvp=True + Not(), + test_data(), + Not().aten_op, + Not().exir_op, + run_on_fvp=True, + atol=0, + rtol=0, + qtol=0, ) pipeline.pop_stage("quantize") pipeline.pop_stage("check.quant_nodes") diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index e0e85a6395e..55340a565e5 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -19,21 +19,46 @@ TosaPipelineMI, ) - test_data_suite = { # (test_name, test_data, [kernel_size, stride, padding]) - "zeros": lambda: (torch.zeros(1, 1, 4, 8), [2, 2, 1]), + "zeros": lambda: (torch.zeros(1, 1, 4, 8), [(4, 6), 2, (2, 0)]), "ones": lambda: (torch.ones(1, 16, 50, 32), [4, 2, 0]), "rand": lambda: (torch.rand(1, 16, 52, 16), [4, 3, 0]), "non_divisible": lambda: (torch.rand(1, 16, 112, 112), [3, 2, 1]), "non_divisible_window_height": lambda: (torch.rand(1, 16, 56, 56), [3, (2, 1), 1]), "non_divisible_window_width": lambda: (torch.rand(1, 16, 56, 56), [3, (1, 2), 1]), + "non_divisible_ceil_mode": lambda: ( + torch.rand(1, 16, 112, 112), + [3, 2, 1, 1, True], + ), + "non_divisible_window_height_ceil_mode": lambda: ( + torch.rand(1, 16, 56, 56), + [3, (2, 1), 1, 1, True], + ), + "non_divisible_window_width_ceil_mode": lambda: ( + torch.rand(1, 16, 56, 56), + [3, (1, 2), 1, 1, True], + ), } test_data_suite_mult_batches = { "randn": lambda: (torch.randn(5, 16, 50, 32), [4, 2, 0]), } +test_data_suite_dilation = [ + # Simple dilation=2 on 8x8 input, kernel=3, stride=1, no padding + ("dilation2", torch.rand(1, 1, 8, 8), [3, 1, 0, 2]), + # Input is 6x6, kernel=3, stride=1, dilation=2. + # Padding=1 expands the effective input to 8x8. + ("pad_then_dil2", torch.rand(1, 1, 6, 6), [3, 1, 1, 2]), + # Input is 16x16, kernel=2x2, stride=2x2, dilation=1 (no dilation). + # Padding of 1 ensures the input size remains divisible by stride + # after padding. + ("even_kernel_fast", torch.rand(1, 3, 16, 16), [(2, 2), (2, 2), (1, 1), 1]), + # Multi-batch, multi-channel input (N=4, C=3), kernel=3x3, + # stride=3x3, no padding, dilation=1. + ("mb_ch_dil1", torch.rand(4, 3, 12, 12), [(3, 3), (3, 3), 0, 1]), +] aten_op = "torch.ops.aten.max_pool2d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_max_pool2d_default" @@ -47,10 +72,16 @@ def __init__( kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int], padding: int | Tuple[int, int], + dilation: int | Tuple[int, int] = 1, + ceil_mode: bool = False, ): super().__init__() self.max_pool_2d = torch.nn.MaxPool2d( - kernel_size=kernel_size, stride=stride, padding=padding + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, ) def forward(self, x): @@ -180,3 +211,41 @@ def test_max_pool2d_u55_BI_failure_set(test_data: Tuple): ) pipeline.pop_stage("check_count.exir") pipeline.run() + + +# Convert the list of (name, tensor, params) into the dict-of-lambdas shape +dilation_test_data = { + name: (lambda data=data, params=params: (data, params)) + for name, data, params in test_data_suite_dilation +} + + +@common.parametrize("test_data", dilation_test_data) +def test_max_pool2d_tosa_MI_dilation(test_data): + """ + TOSA MI pipeline with dilation > 1 (and dilation=1 sanity cases). + """ + data, model_params = test_data() + pipeline = TosaPipelineMI[input_t1]( + MaxPool2d(*model_params), + (data,), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", dilation_test_data) +def test_max_pool2d_tosa_BI_dilation(test_data): + """ + TOSA BI pipeline with dilation > 1 (and dilation=1 sanity cases). + """ + data, model_params = test_data() + pipeline = TosaPipelineBI[input_t1]( + MaxPool2d(*model_params), + (data,), + aten_op, + exir_op, + symmetric_io_quantization=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_multihead_attention.py b/backends/arm/test/ops/test_multihead_attention.py new file mode 100644 index 00000000000..e23aff0b9dc --- /dev/null +++ b/backends/arm/test/ops/test_multihead_attention.py @@ -0,0 +1,96 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + + +class MultiheadAttention(torch.nn.MultiheadAttention): + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + +input_t1 = tuple[torch.Tensor, torch.nn.Module] +test_suite = { + # test_name, (x,), embed_dim, num_heads, batch_first + "rand_2d": lambda: ( + (torch.rand(6, 3),), + MultiheadAttention(embed_dim=3, num_heads=3, batch_first=True), + ), + "randn_2d": lambda: ( + (torch.randn(2, 4),), + MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True), + ), + "randn_3d": lambda: ( + (torch.randn(3, 2, 4),), + MultiheadAttention(embed_dim=4, num_heads=2, batch_first=False), + ), +} + + +@common.parametrize( + "test_data", + test_suite, +) +def test_multihead_attention_tosa_MI(test_data: input_t1): + test_data, module = test_data() + pipeline = TosaPipelineMI(module, (*test_data, *test_data, *test_data), [], []) + pipeline.run() + + +@common.parametrize( + "test_data", + test_suite, +) +def test_multihead_attention_tosa_BI(test_data): + test_data, module = test_data() + pipeline = TosaPipelineBI(module, (*test_data, *test_data, *test_data), [], []) + pipeline.run() + + +@common.parametrize( + "test_data", + test_suite, +) +@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP") +@common.XfailIfNoCorstone300 +def test_multihead_attention_u55_BI(test_data: input_t1): + test_data, module = test_data() + pipeline = EthosU55PipelineBI( + module, + (*test_data, *test_data, *test_data), + [], + [], + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@common.parametrize( + "test_data", + test_suite, +) +@pytest.mark.xfail(reason="MLETORCH-1102: Numerical issues on FVP") +@common.XfailIfNoCorstone320 +def test_multihead_attention_u85_BI(test_data: input_t1): + test_data, module = test_data() + pipeline = EthosU85PipelineBI( + module, + (*test_data, *test_data, *test_data), + [], + [], + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_round.py b/backends/arm/test/ops/test_round.py new file mode 100644 index 00000000000..3480076a3e1 --- /dev/null +++ b/backends/arm/test/ops/test_round.py @@ -0,0 +1,84 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import pytest +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +input_t1 = Tuple[torch.Tensor] # Input x + +aten_op = "torch.ops.aten.round.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_round_default" + +test_data_suite = { + # (test_name, test_data) + "zeros": lambda: torch.zeros(1, 10, 10, 10), + "ones": lambda: torch.ones(10, 10, 10), + "rand": lambda: torch.rand(10, 10) - 0.5, + "randn_pos": lambda: torch.randn(10) + 10, + "randn_neg": lambda: torch.randn(10) - 10, + "ramp": lambda: torch.arange(-16, 16, 0.2), +} + + +class Round(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x.round() + + +@common.parametrize("test_data", test_data_suite) +def test_round_tosa_MI(test_data: torch.Tensor): + pipeline = TosaPipelineMI[input_t1]( + Round(), + (test_data(),), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_round_tosa_BI(test_data: torch.Tensor): + pipeline = TosaPipelineBI[input_t1]( + Round(), + (test_data(),), + [], + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +@pytest.mark.xfail(reason="where.self not supported on U55") +def test_round_u55_BI(test_data: torch.Tensor): + pipeline = EthosU55PipelineBI[input_t1]( + Round(), + (test_data(),), + [], + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_round_u85_BI(test_data: torch.Tensor): + pipeline = EthosU85PipelineBI[input_t1]( + Round(), + (test_data(),), + [], + exir_op, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 59ebcc15270..c1e958174cf 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -33,6 +33,7 @@ class Sum(torch.nn.Module): "4d_dims_no_keep": lambda: (torch.rand(1, 1, 5, 8), 1, False), "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True), "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True), + "dim_None": lambda: (torch.rand(10), None, True), } def forward(self, x: torch.Tensor, dim: int, keepdim: bool): diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index 7bfd27ac0a8..a60cf587a3e 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -121,6 +121,12 @@ def scalar_condition(input: torch.Tensor): scalar_condition, ) +int32_scalar_cond = Where( + 1, + torch.int32, + scalar_condition, +) + test_modules_common = { "two_dim_tensor_cond": lambda: two_dim_tensor_cond, "three_dim_tensor_cond": lambda: three_dim_tensor_cond, @@ -134,6 +140,7 @@ def scalar_condition(input: torch.Tensor): **test_modules_common, "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool, + "int32_scalar_cond": lambda: int32_scalar_cond, } test_modules_BI = { diff --git a/backends/arm/test/passes/test_cast_int64_pass.py b/backends/arm/test/passes/test_cast_int64_pass.py index b9ddfcdec86..7832fd87ed9 100644 --- a/backends/arm/test/passes/test_cast_int64_pass.py +++ b/backends/arm/test/passes/test_cast_int64_pass.py @@ -11,6 +11,8 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType + input_t = Tuple[torch.Tensor] # Input x @@ -40,6 +42,8 @@ def test_int64_model(test_data: input_t): ) pipeline.run() - exported_program = pipeline.tester.get_artifact("RunPasses").exported_program() + exported_program = pipeline.tester.get_artifact( + StageType.RUN_PASSES + ).exported_program() for state in exported_program.state_dict: assert exported_program.state_dict[state].dtype == torch.int32 diff --git a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py index 49626eefb71..9a26157ed7e 100644 --- a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py @@ -10,7 +10,10 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) -from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.test.tester.test_pipeline import ( + PassPipeline, + TosaPipelineMI, +) input_t = Tuple[torch.Tensor] # Input x @@ -54,6 +57,25 @@ def forward(self, x): return self.fc1(x) + self.fc2(x) +class NotFuseTensorWithDifferentType(torch.nn.Module): + + ops_before_pass = {} + ops_after_pass = {} + ops_not_after_pass = [] + + def forward(self, x: torch.Tensor, y: torch.Tensor): + """ + Args: + x: A float tensor (dtype=torch.float32) + y: An int tensor (dtype=torch.int32) + """ + a = torch.tensor(1.0, dtype=torch.float32) + b = torch.tensor(1, dtype=torch.int32) + m = x < a + n = y > b + return m, n + + def test_fuse_equal_placeholders_constants_tosa_MI(): module = FuseWeightsConstants() data = (torch.rand(1, 2, 8),) @@ -94,3 +116,24 @@ def test_fuse_equal_placeholders_state_dict_tosa_MI(): assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed" assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed" assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed" + + +def test_not_fuse_tensor_with_different_type_MI(): + module = NotFuseTensorWithDifferentType() + data = ( + torch.rand( + 1, + ), + torch.randint( + 0, + 10, + (1,), + dtype=torch.int, + ), + ) + pipeline = TosaPipelineMI[input_t]( + module, + data, + aten_op=[], + ) + pipeline.run() diff --git a/backends/arm/test/quantizer/test_generic_annotater.py b/backends/arm/test/quantizer/test_generic_annotater.py index d87521485e5..4a4a333084c 100644 --- a/backends/arm/test/quantizer/test_generic_annotater.py +++ b/backends/arm/test/quantizer/test_generic_annotater.py @@ -9,6 +9,7 @@ import torch from executorch.backends.arm.quantizer import is_annotated from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineBI +from executorch.backends.test.harness.stages import StageType from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -36,7 +37,7 @@ def check_annotation(model): pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() - artifact = pipeline.tester.get_artifact("Quantize") + artifact = pipeline.tester.get_artifact(StageType.QUANTIZE) partitions = get_source_partitions(artifact.graph, [model.op]) partitions = list(itertools.chain.from_iterable(partitions.values())) diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index d4c9d8e8dc0..96060b7b563 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -14,7 +14,7 @@ get_output_quantization_params, ) -from executorch.backends.xnnpack.test.tester.tester import Export, Quantize +from executorch.backends.test.harness.stages import StageType logger = logging.getLogger(__name__) @@ -22,22 +22,36 @@ def _print_channels(result, reference, channels_close, C, H, W, rtol, atol): output_str = "" + booldata = False + if reference.dtype == torch.bool or result.dtype == torch.bool: + booldata = True + for c in range(C): if channels_close[c]: continue - - max_diff = torch.max(torch.abs(reference - result)) - exp = f"{max_diff:2e}"[-3:] - output_str += f"channel {c} (e{exp})\n" + if not booldata: + max_diff = torch.max(torch.abs(reference - result)) + exp = f"{max_diff:2e}"[-3:] + output_str += f"channel {c} (e{exp})\n" + else: + max_diff = torch.max(reference ^ result) + output_str += f"channel {c} (bool)\n" for y in range(H): res = "[" for x in range(W): if torch.allclose(reference[c, y, x], result[c, y, x], rtol, atol): - res += " . " + if not booldata: + res += " . " + else: + res += " . " else: - diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp)) - res += f"{diff: .2f} " + if not booldata: + diff = (reference[c, y, x] - result[c, y, x]) / 10 ** (int(exp)) + res += f"{diff: .2f} " + else: + diff = reference[c, y, x] ^ result[c, y, x] + res += " X " # Break early for large widths if x == 16: @@ -157,12 +171,6 @@ def print_error_diffs( result_batch = result[n, :, :, :] reference_batch = reference[n, :, :, :] - if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: - mismatches = (reference_batch != result_batch).sum().item() - total = reference_batch.numel() - output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" - continue - is_close = torch.allclose(result_batch, reference_batch, rtol, atol) if is_close: output_str += ".\n" @@ -189,6 +197,11 @@ def print_error_diffs( output_str += _print_elements( result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol ) + if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool: + mismatches = (reference_batch != result_batch).sum().item() + total = reference_batch.numel() + output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n" + # Only compute numeric error metrics if tensor is not boolean if reference.dtype != torch.bool and result.dtype != torch.bool: reference_range = torch.max(reference) - torch.min(reference) @@ -238,8 +251,8 @@ def dump_error_output( if path_to_tosa_files is None: path_to_tosa_files = tempfile.mkdtemp(prefix="executorch_result_dump_") - export_stage = tester.stages.get(tester.stage_name(Export), None) - quantize_stage = tester.stages.get(tester.stage_name(Quantize), None) + export_stage = tester.stages.get(StageType.EXPORT, None) + quantize_stage = tester.stages.get(StageType.QUANTIZE, None) if export_stage is not None and quantize_stage is not None: output_nodes = get_output_nodes(export_stage.artifact) qp_input = get_input_quantization_params(export_stage.artifact) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index ccec1019144..04034521f9b 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -61,6 +61,7 @@ from executorch.backends.arm.tosa_partitioner import TOSAPartitioner from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.test.harness.stages import Stage, StageType from executorch.backends.xnnpack.test.tester import Tester from executorch.devtools.backend_debug import get_delegation_info @@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram: super().run(artifact, inputs) -class InitialModel(tester.Stage): +class InitialModel(Stage): def __init__(self, model: torch.nn.Module): self.model = model + def stage_type(self) -> StageType: + return StageType.INITIAL_MODEL + def run(self, artifact, inputs=None) -> None: pass @@ -305,16 +309,19 @@ def __init__( self.constant_methods = constant_methods self.compile_spec = compile_spec super().__init__(model, example_inputs, dynamic_shapes) - self.pipeline[self.stage_name(InitialModel)] = [ - self.stage_name(tester.Quantize), - self.stage_name(tester.Export), + self.pipeline[StageType.INITIAL_MODEL] = [ + StageType.QUANTIZE, + StageType.EXPORT, ] # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry. - self.stages[self.stage_name(InitialModel)] = None + self.stages[StageType.INITIAL_MODEL] = None self._run_stage(InitialModel(self.original_module)) - def quantize(self, quantize_stage: Optional[tester.Quantize] = None): + def quantize( + self, + quantize_stage: Optional[tester.Quantize] = None, + ): if quantize_stage is None: quantizer = None if is_tosa(self.compile_spec): @@ -324,7 +331,7 @@ def quantize(self, quantize_stage: Optional[tester.Quantize] = None): quantizer = EthosUQuantizer(self.compile_spec) quantize_stage = tester.Quantize( quantizer, - get_symmetric_quantization_config(is_per_channel=False), + get_symmetric_quantization_config(), ) return super().quantize(quantize_stage) @@ -410,7 +417,7 @@ def serialize( return super().serialize(serialize_stage) def is_quantized(self) -> bool: - return self.stages[self.stage_name(tester.Quantize)] is not None + return self.stages[StageType.QUANTIZE] is not None def run_method_and_compare_outputs( self, @@ -439,18 +446,16 @@ def run_method_and_compare_outputs( """ if not run_eager_mode: - edge_stage = self.stages[self.stage_name(tester.ToEdge)] + edge_stage = self.stages[StageType.TO_EDGE] if edge_stage is None: - edge_stage = self.stages[ - self.stage_name(tester.ToEdgeTransformAndLower) - ] + edge_stage = self.stages[StageType.TO_EDGE_TRANSFORM_AND_LOWER] assert ( edge_stage is not None ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run." else: # Run models in eager mode. We do this when we want to check that the passes # are numerically accurate and the exported graph is correct. - export_stage = self.stages[self.stage_name(tester.Export)] + export_stage = self.stages[StageType.EXPORT] assert ( export_stage is not None ), "To compare outputs in eager mode, the model must be at Export stage" @@ -460,11 +465,11 @@ def run_method_and_compare_outputs( is_quantized = self.is_quantized() if is_quantized: - reference_stage = self.stages[self.stage_name(tester.Quantize)] + reference_stage = self.stages[StageType.QUANTIZE] else: - reference_stage = self.stages[self.stage_name(InitialModel)] + reference_stage = self.stages[StageType.INITIAL_MODEL] - exported_program = self.stages[self.stage_name(tester.Export)].artifact + exported_program = self.stages[StageType.EXPORT].artifact output_nodes = get_output_nodes(exported_program) output_qparams = get_output_quantization_params(output_nodes) @@ -474,7 +479,7 @@ def run_method_and_compare_outputs( quantization_scales.append(getattr(output_qparams[node], "scale", None)) logger.info( - f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'" + f"Comparing Stage '{test_stage.stage_type()}' with Stage '{reference_stage.stage_type()}'" ) # Loop inputs and compare reference stage with the compared stage. @@ -491,7 +496,6 @@ def run_method_and_compare_outputs( reference_outputs, _ = pytree.tree_flatten( reference_stage.run_artifact(reference_input) ) - if run_eager_mode: # Run exported module directly test_outputs, _ = pytree.tree_flatten( @@ -505,6 +509,10 @@ def run_method_and_compare_outputs( test_stage.run_artifact(reference_input) ) + logger.info(f"\n Input: {reference_input}") + logger.info(f"\n Ref output: {reference_outputs}") + logger.info(f"\nTest output: {test_outputs}") + for reference_output, test_output, quantization_scale in zip( reference_outputs, test_outputs, quantization_scales ): @@ -525,14 +533,12 @@ def get_graph(self, stage: str | None = None) -> Graph: stage = self.cur artifact = self.get_artifact(stage) if ( - self.cur == self.stage_name(tester.ToEdge) - or self.cur == self.stage_name(Partition) - or self.cur == self.stage_name(ToEdgeTransformAndLower) + self.cur == StageType.TO_EDGE + or self.cur == StageType.PARTITION + or self.cur == StageType.TO_EDGE_TRANSFORM_AND_LOWER ): graph = artifact.exported_program().graph - elif self.cur == self.stage_name(tester.Export) or self.cur == self.stage_name( - tester.Quantize - ): + elif self.cur == StageType.EXPORT or self.cur == StageType.QUANTIZE: graph = artifact.graph else: raise RuntimeError( @@ -553,13 +559,13 @@ def dump_operator_distribution( Returns self for daisy-chaining. """ line = "#" * 10 - to_print = f"{line} {self.cur.capitalize()} Operator Distribution {line}\n" + to_print = f"{line} {self.cur} Operator Distribution {line}\n" if ( self.cur in ( - self.stage_name(tester.Partition), - self.stage_name(ToEdgeTransformAndLower), + StageType.PARTITION, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, ) and print_table ): @@ -599,9 +605,7 @@ def dump_dtype_distribution( """ line = "#" * 10 - to_print = ( - f"{line} {self.cur.capitalize()} Placeholder Dtype Distribution {line}\n" - ) + to_print = f"{line} {self.cur} Placeholder Dtype Distribution {line}\n" graph = self.get_graph(self.cur) tosa_spec = get_tosa_spec(self.compile_spec) @@ -650,7 +654,7 @@ def run_transform_for_annotation_pipeline( stage = self.cur # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run. artifact = self.get_artifact(stage) - if self.cur == self.stage_name(tester.Export): + if self.cur == StageType.EXPORT: new_gm = ArmPassManager(get_tosa_spec(self.compile_spec)).transform_for_annotation_pipeline( # type: ignore[arg-type] graph_module=artifact.graph_module ) diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 23d51dcaba1..7f0ad5ce8c8 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -299,6 +299,7 @@ def __init__( run_on_tosa_ref_model: bool = True, tosa_version: str = "TOSA-0.80+BI", symmetric_io_quantization: bool = False, + per_channel_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, atol: float = 1e-03, @@ -315,16 +316,17 @@ def __init__( compile_spec = common.get_tosa_compile_spec( tosa_profiles[tosa_version], custom_path=custom_path ) - quant_stage = ( - Quantize( - TOSAQuantizer(tosa_profiles[tosa_version]).set_io( - get_symmetric_quantization_config() - ), - get_symmetric_quantization_config(), + if symmetric_io_quantization or per_channel_quantization: + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization ) - if symmetric_io_quantization - else None - ) + if symmetric_io_quantization: + quantizer.set_io(quantization_config) + quant_stage = Quantize(quantizer, quantization_config) + else: + quant_stage = None + super().__init__( module, test_data, @@ -472,6 +474,7 @@ def __init__( exir_ops: Optional[str | List[str]] = None, run_on_fvp: bool = True, symmetric_io_quantization: bool = False, + per_channel_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, atol: float = 1e-03, @@ -479,16 +482,17 @@ def __init__( qtol: int = 1, ): compile_spec = common.get_u55_compile_spec(custom_path=custom_path) - quant_stage = ( - Quantize( - EthosUQuantizer(compile_spec).set_io( - get_symmetric_quantization_config() - ), - get_symmetric_quantization_config(), + if symmetric_io_quantization or per_channel_quantization: + quantizer = EthosUQuantizer(compile_spec) + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization ) - if symmetric_io_quantization - else None - ) + if symmetric_io_quantization: + quantizer.set_io(quantization_config) + quant_stage = Quantize(quantizer, quantization_config) + else: + quant_stage = None + super().__init__( module, test_data, @@ -560,6 +564,7 @@ def __init__( exir_ops: str | List[str] = None, run_on_fvp: bool = True, symmetric_io_quantization: bool = False, + per_channel_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, atol: float = 1e-03, @@ -567,16 +572,17 @@ def __init__( qtol: int = 1, ): compile_spec = common.get_u85_compile_spec(custom_path=custom_path) - quant_stage = ( - Quantize( - EthosUQuantizer(compile_spec).set_io( - get_symmetric_quantization_config() - ), - get_symmetric_quantization_config(), + if symmetric_io_quantization or per_channel_quantization: + quantizer = EthosUQuantizer(compile_spec) + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization ) - if symmetric_io_quantization - else None - ) + if symmetric_io_quantization: + quantizer.set_io(quantization_config) + quant_stage = Quantize(quantizer, quantization_config) + else: + quant_stage = None + super().__init__( module, test_data, diff --git a/backends/arm/tosa_backend.py b/backends/arm/tosa_backend.py index fdada0b889a..0f03e12c916 100644 --- a/backends/arm/tosa_backend.py +++ b/backends/arm/tosa_backend.py @@ -113,6 +113,8 @@ def preprocess( # noqa: C901 if node.op == "call_function": process_call_function(node, tosa_graph, node_visitors, tosa_spec) elif node.op == "placeholder": + if len(node.users) == 0: + continue process_placeholder(node, tosa_graph, edge_program, tosa_spec) if node.name in edge_program.graph_signature.user_inputs: input_count += 1 diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index de85dfae92f..ee7d1733f37 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -104,12 +104,14 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: if not is_partitioned(input): del node.meta["delegation_tag"] break + continue if is_dequant_node(node): for user in node.users: if not is_partitioned(user): del node.meta["delegation_tag"] break + continue if tosa_spec.support_float(): continue diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 10dc810da6b..aad4bab3eb1 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -23,9 +23,23 @@ from tosa.RoundingMode import RoundingMode # type: ignore -q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default -dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = (q_op, dq_op) +q_ops = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, +) +dq_ops = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, +) +per_tensor_q_dq_ops = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, +) +per_channel_q_dq_ops = ( + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, +) +dq_q_ops = (*q_ops, *dq_ops) def insert_rescale_ops_to_int32( @@ -61,14 +75,14 @@ def insert_rescale_ops_to_int32( # Scale the int8 quantized input to a common scale in the integer # domain - min_scale = min([qarg.scale for qarg in qargs]) - scales = [qarg.scale / min_scale for qarg in qargs] + min_scale = min([qarg.get_scale_per_tensor() for qarg in qargs]) + scales = [qarg.get_scale_per_tensor() / min_scale for qarg in qargs] rescaled_nodes: list[Any] = [] for tensor, qarg, scale in zip(tensors, qargs, scales): rescaled_nodes.append( build_rescale_to_int32( - tosa_graph, tensor, qarg.zp, [scale], tosa_spec=tosa_spec + tosa_graph, tensor, qarg.get_zp_per_tensor(), scale, tosa_spec=tosa_spec ) ) return rescaled_nodes, min_scale @@ -100,51 +114,125 @@ def insert_rescale_op_to_int8( assert len(output_qparams) == 1, "More than one output not supported" qargs_out = output_qparams[0] - output_rescale_scale = scale / qargs_out.scale + output_rescale_scale = scale / qargs_out.get_scale_per_tensor() # Rescale Back to INT8 build_rescale_from_int32( tosa_graph, last_tensor, node.name, - qargs_out.zp, - [output_rescale_scale], + qargs_out.get_zp_per_tensor(), + output_rescale_scale, tosa_spec=tosa_spec, ) class QuantArgs(NamedTuple): - scale: float - zp: int + scale: list[float] | float + zp: list[int] | int qmin: int qmax: int dtype: torch.dtype + axis: int = 0 + per_channel: bool = False def quantize_value(self, x: torch.Tensor | float) -> Tensor: + """Quantizes the input tensor or value to a quantized tensor. If the input is + not a tensor, it is converted to a tensor first. If self.per_channel is True, + the quantization is done per channel, otherwise it is done per tensor. + """ if not isinstance(x, torch.Tensor): x = torch.Tensor([x]) - return torch.clip( - torch.round(x / self.scale) + self.zp, - self.qmin, - self.qmax, - ).to(self.dtype) + x = x.to(torch.float32) + if self.per_channel: + q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default + args = (x, self.scale, self.zp, self.axis, self.qmin, self.qmax, self.dtype) + else: + q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment] + + return q_op(*args) def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: - return (qx.to(torch.int64) - self.zp) * self.scale + """Dequantizes the input tensor or value to a dequantized tensor If the input + is not a tensor, it is converted to a tensor first. If self.per_channel is True, + the dequantization is done per channel, otherwise it is done per tensor. + """ + if self.per_channel: + dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default + args = ( + qx, + self.scale, + self.zp, + self.axis, + self.qmin, + self.qmax, + self.dtype, + ) + else: + dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment] + + return dq_op(*args) @classmethod def from_operator(cls, op, args): - if op in dq_q_ops: + if op in per_tensor_q_dq_ops: return cls( scale=cast(float, args[1]), zp=cast(int, args[2]), qmin=cast(int, args[3]), qmax=cast(int, args[4]), dtype=cast(torch.dtype, args[5]), + axis=0, + per_channel=False, ) + elif op in per_channel_q_dq_ops: + return cls( + scale=cast(list[float], args[1].tolist()), + zp=cast(list[int], args[2].tolist()), + axis=cast(int, args[3]), + qmin=cast(int, args[4]), + qmax=cast(int, args[5]), + dtype=cast(torch.dtype, args[6]), + per_channel=True, + ) + else: - # We're only handling per tensor quantization - raise NotImplementedError + # We're only handling per tensor and per channel quantization + raise NotImplementedError(f"Unsupported quantization operation: {op}") + + def get_scale_per_tensor(self) -> float: + if not isinstance(self.scale, float): + raise TypeError( + f"Expected scale {self.scale} to be a float but found scale of " + f"type {type(self.scale)}" + ) + return self.scale + + def get_zp_per_tensor(self) -> int: + if not isinstance(self.zp, int): + raise TypeError( + f"Expected zero point {self.zp} to be an int but found zp of " + f"type {type(self.zp)}" + ) + return self.zp + + def get_scale_per_channel(self) -> list[float]: + if not isinstance(self.scale, list): + raise TypeError( + f"Expected scale {self.scale} to be a list but found scale of " + f"type {type(self.scale)}" + ) + return self.scale + + def get_zp_per_channel(self) -> list[int]: + if not isinstance(self.zp, list): + raise TypeError( + f"Expected zero point {self.zp} to be a list but found zp of " + f"type {type(self.zp)}" + ) + return self.zp # TOSA uses the RESCALE operation to scale between values with differing precision. @@ -200,8 +288,8 @@ def build_rescale_v0_80( input_node: Any, output_name: str, output_type: Any, - input_zp: int, - output_zp: int, + input_zp: list[int], + output_zp: list[int], is_double_round: bool = False, per_channel=False, ): @@ -215,8 +303,8 @@ def build_rescale_v0_80( attr_rescale = ts.TosaSerializerAttribute() attr_rescale.RescaleAttribute( - input_zp=input_zp, - output_zp=output_zp, + input_zp=input_zp[0], + output_zp=output_zp[0], multiplier=multipliers, shift=shifts, scale32=is_scale32, @@ -258,10 +346,10 @@ def create_const_ops_for_rescale( (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" ) input_zp = tosa_fb.addConst( - [1], input_dtype, [input_zp], name=node_name + "_input_zp" + [1], input_dtype, input_zp, name=node_name + "_input_zp" ) output_zp = tosa_fb.addConst( - [1], output_dtype, [output_zp], name=node_name + "_output_zp" + [1], output_dtype, output_zp, name=node_name + "_output_zp" ) return [multipliers.name, shifts.name, input_zp.name, output_zp.name] @@ -273,8 +361,8 @@ def build_rescale( input_node: Any, output_name: str, output_type: Any, - input_zp: int, - output_zp: int, + input_zp: list[int], + output_zp: list[int], rounding_mode: RoundingMode, per_channel=False, ): @@ -319,7 +407,7 @@ def build_rescale_to_int32( tosa_fb: Any, input_arg: TosaArg, input_zp: int, - rescale_scale: list[float], + rescale_scale: float, is_scale32: bool = True, is_double_round: bool = False, per_channel: bool = False, @@ -336,12 +424,12 @@ def build_rescale_to_int32( build_rescale_v0_80( tosa_fb=tosa_fb, - scale=rescale_scale, + scale=[rescale_scale], input_node=input_arg, output_name=input_A_rescaled_to_int32.name, output_type=ts.DType.INT32, - input_zp=input_zp, - output_zp=0, + input_zp=[input_zp], + output_zp=[0], ) # type: ignore[call-arg] elif isinstance(tosa_spec, tosa_specification.Tosa_1_00): @@ -355,12 +443,12 @@ def build_rescale_to_int32( build_rescale( tosa_fb, - rescale_scale, + [rescale_scale], input_arg, input_A_rescaled_to_int32.name, ts.DType.INT32, - input_zp, - 0, + [input_zp], + [0], rounding_mode=RoundingMode.SINGLE_ROUND, ) # type: ignore[call-arg] @@ -372,7 +460,7 @@ def build_rescale_from_int32( input_node: TosaArg, output_name: str, output_zp: int, - rescale_scale: list[float], + rescale_scale: float, is_scale32: bool = True, is_double_round: bool = False, per_channel: bool = False, @@ -384,12 +472,12 @@ def build_rescale_from_int32( build_rescale_v0_80( tosa_fb=tosa_fb, - scale=rescale_scale, + scale=[rescale_scale], input_node=input_node, output_name=output_name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=output_zp, + input_zp=[0], + output_zp=[output_zp], ) # type: ignore[call-arg] elif isinstance(tosa_spec, tosa_specification.Tosa_1_00): @@ -399,12 +487,12 @@ def build_rescale_from_int32( # to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale build_rescale( tosa_fb, - rescale_scale, + [rescale_scale], input_node, output_name=output_name, output_type=ts.DType.INT8, - input_zp=0, - output_zp=output_zp, + input_zp=[0], + output_zp=[output_zp], rounding_mode=RoundingMode.SINGLE_ROUND, ) # type: ignore[call-arg] return @@ -421,7 +509,7 @@ def build_rescale_conv_output( input_scale: list[float], weight_scale: list[float], output_scale: list[float], - output_zp: int, + output_zp: list[int], tosa_spec=None, ): # TODO add check to verify if this is a Per-channel quantization. @@ -438,7 +526,7 @@ def build_rescale_conv_output( input_node=op, output_name=output_name, output_type=output_type, - input_zp=0, + input_zp=[0], output_zp=output_zp, per_channel=isinstance(weight_scale, torch.Tensor), ) # type: ignore[call-arg] @@ -451,7 +539,7 @@ def build_rescale_conv_output( input_node=op, output_name=output_name, output_type=output_type, - input_zp=0, + input_zp=[0], output_zp=output_zp, rounding_mode=RoundingMode.SINGLE_ROUND, per_channel=isinstance(weight_scale, torch.Tensor), diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index a176bc62973..3b56fdd1cbf 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -15,14 +15,19 @@ import torch import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.print_program import inspect_node + +from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node -from tosa_tools.v0_80.serializer.tosa_serializer import TosaOp logger = logging.getLogger(__name__) @@ -116,17 +121,149 @@ def get_output_node(node: Node) -> Node: def build_reshape(tosa_fb, input_name, new_shape, output_name): attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute(new_shape) - tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr) + tosa_fb.addOperator(ts.TosaOp.Op().RESHAPE, [input_name], [output_name], attr) + + +def are_fake_tensors_broadcastable( + fake_tensors: list[FakeTensor], +) -> tuple[bool, list[int]]: + """ + Determines whether a list of FakeTensors can be broadcast together. + Args: + fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors + who's shapes to evaluate + Returns: + tuple[bool, list[int]]: First element is whether the shapes are + broadcastable. Second element is the common shape if compatible. + If not, empty list. -def build_reshape_tosa_1_0(tosa_graph, input_name, new_shape, output_name): + Raises: + RuntimeError: If less than 2 tensors are passed in. + """ + if len(fake_tensors) < 1: + raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}") + + reversed_shapes = [list(reversed(ft.shape)) for ft in fake_tensors] + sorted_shapes = sorted(reversed_shapes, key=len, reverse=True) + + broadcast_shape = [] + for dim in range(len(sorted_shapes[0])): + curr_dim = 1 + for shape in sorted_shapes: + if dim >= len(shape): + continue + if curr_dim == 1 and shape[dim] != 1: + curr_dim = shape[dim] + elif shape[dim] == 1: + continue + elif curr_dim != 1 and shape[dim] != curr_dim: + return (False, []) + broadcast_shape.append(curr_dim) + return (True, list(reversed(broadcast_shape))) + + +def broadcast_tensors( + tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification +) -> list[Any]: + """ + Given a list of nodes it determines the common shape they broadcast to + and adds the necessary reshape and tile operations to perform the broadcast. + + Args: + tosa_fb: Tosa graph to add nodes to + nodes (list[Node]): List of nodes to broadcast together + tosa_spec (TosaSpecification): Tosa spec + + Returns: + list[Any]: List containing the fx.Nodes or TosaSerializerTensors + of the right common shape. Order of output matches order of input. + + Raises: + RuntimeError: If the supplied nodes are not broadcastable. + + Note: + This function and `reshape_for_broadcast` both reshape the tensors + for broadcast. However this function also performs the broadcast and + does not have a limit on only two input tensors. + """ + if isinstance(tosa_spec, Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + reshape_helper = build_reshape + elif isinstance(tosa_spec, Tosa_1_00): + import serializer.tosa_serializer as ts + + reshape_helper = build_reshape_tosa_1_0 + else: + raise ValueError(f"Unsupported TOSA spec: {tosa_spec}") + + index_fake_tensors = [node.meta["val"] for node in nodes] + broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors) + if not broadcastable: + raise RuntimeError("FakeTensors are not broadcastable") + + broadcast_tensors = [] + for node in nodes: + tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta, tosa_spec) + list_tens_shape = list(tens_shape) + # Already in the right shape we can just add it to the list. + if list_tens_shape == common_shape: + broadcast_tensors.append(node) + continue + + rank_diff = len(common_shape) - len(tens_shape) + new_shape = [1] * rank_diff + list_tens_shape + reshaped = tosa_fb.addIntermediate( + new_shape, + tens_dtype, + ) + + reshape_helper(tosa_fb, node.name, new_shape, reshaped.name) + + tiled = tosa_fb.addIntermediate(common_shape, tens_dtype) + multipliers = [ + comm if curr == 1 else 1 for comm, curr in zip(common_shape, new_shape) + ] + if isinstance(tosa_spec, Tosa_0_80): + attr = ts.TosaSerializerAttribute() + attr.TileAttribute(multipliers) + tosa_fb.addOperator( + ts.TosaOp.Op().TILE, + [reshaped.name], + [tiled.name], + attr, + ) + elif isinstance(tosa_spec, Tosa_1_00): + multiple_shapes = tosa_fb.addConst( + (len(multipliers),), + ts.DType.SHAPE, + multipliers, + name=f"{node.name}_multiples", + ) + + tosa_fb.addOperator( + ts.TosaOp.Op().TILE, + [reshaped.name, multiple_shapes.name], + [tiled.name], + None, + ) + + broadcast_tensors.append(tiled) + + return broadcast_tensors + + +def build_reshape_tosa_1_0( + tosa_graph, input_name, new_shape, output_name, shape_name_override="" +): import serializer.tosa_serializer as ts_ # type: ignore shape = tosa_graph.addConst( np.array(new_shape).shape, ts_.DType.SHAPE, np.array(new_shape), - name=output_name + "_shape", + name=shape_name_override if shape_name_override else output_name + "_shape", ) attr = ts_.TosaSerializerAttribute() diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index a0de747cf3f..a85cc0ca925 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -12,6 +12,7 @@ load( "CXX", ) load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib") +load("@fbcode_macros//build_defs:cpp_python_extension.bzl", "cpp_python_extension") oncall("odai_jarvis") @@ -275,7 +276,6 @@ python_library( "//executorch/exir/passes:spec_prop_pass", ], ) - python_library( name = "decompose_ops", srcs = [ @@ -293,6 +293,14 @@ python_library( ], ) +python_library( + name = "typing_stubs", + srcs = [ + "typing_stubs.py", + ], + typing = True, +) + python_unittest( name = "test_graph_builder", @@ -321,6 +329,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", ":compiler", + ":typing_stubs", ":replace_ops", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -344,6 +353,7 @@ python_unittest( ":compiler", ":decompose_ops", "//caffe2:torch", + ":typing_stubs", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:graph_builder", "//executorch/backends/cadence/aot:pass_utils", @@ -363,6 +373,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", ":compiler", + ":typing_stubs", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", "//executorch/backends/cadence/aot:fuse_ops", @@ -384,6 +395,7 @@ python_unittest( deps = [ "fbsource//third-party/pypi/parameterized:parameterized", "fbsource//third-party/pypi/pyre-extensions:pyre-extensions", + ":typing_stubs", ":compiler", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -404,6 +416,7 @@ python_unittest( supports_static_listing = False, typing = True, deps = [ + ":typing_stubs", "fbsource//third-party/pypi/parameterized:parameterized", "//caffe2:torch", "//executorch/backends/cadence/aot:compiler", @@ -435,6 +448,22 @@ python_unittest( ], ) +python_library( + name = "memory_planning_algo", + srcs = [ + "memory_planning_algo.py", + ], + deps = [ + ":memory_constraints", + ":pass_utils", + "//executorch/exir:lib", + "//executorch/exir:memory_planning", + "//executorch/exir:tensor", + "//executorch/exir/passes:lib", + "fbsource//third-party/pypi/tabulate:tabulate", + ], +) + python_library( name = "memory_planning", srcs = [ @@ -443,6 +472,7 @@ python_library( deps = [ "fbsource//third-party/pypi/tabulate:tabulate", ":memory_constraints", + ":memory_planning_algo", ":pass_utils", "//caffe2:torch", "//executorch/exir:lib", @@ -477,6 +507,7 @@ python_unittest( deps = [ ":compiler", ":memory_planning", + ":typing_stubs", ":ops_registrations", ":pass_utils", "//caffe2:torch", diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index cb91c459cfd..560b625e4c0 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -278,6 +278,7 @@ def quantize_and_export_to_edge( dump_graphs: bool = False, constant_methods: Optional[dict[str, object]] = None, calibration_data: Optional[list[tuple[object, ...]]] = None, + core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None, ) -> EdgeProgramManager: """ Trace, quantize and lower a model/inputs pair to edge IR. @@ -294,6 +295,7 @@ def quantize_and_export_to_edge( quantized_model, dump_graphs=dump_graphs, constant_methods=constant_methods, + core_aten_exceptions=core_aten_exceptions, ) diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index aaf7f051b09..5c7f10729cc 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -712,32 +712,14 @@ def _create_requantize_node( out_dtype: torch.dtype, graph: torch.fx.Graph, ) -> torch.fx.Node: - in_scale_tensor = graph.call_function( - exir_ops.edge.aten.full.default, args=((1,), in_scale) - ) - in_zero_point_tensor = graph.call_function( - exir_ops.edge.aten.full.default, - args=((1,), in_zero_point), - kwargs={"dtype": torch.int32}, - ) - out_scale_tensor = graph.call_function( - exir_ops.edge.aten.full.default, args=((1,), out_scale) - ) - out_zero_point_tensor = graph.call_function( - exir_ops.edge.aten.full.default, - args=((1,), out_zero_point), - kwargs={"dtype": torch.int32}, - ) - # cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y - # TODO(hardiksharma): Add support for per-tensor requantize. return graph.call_function( - exir_ops.edge.cadence.requantize.default, + exir_ops.edge.cadence.requantize.per_tensor, args=( in_tensor, - in_scale_tensor, - in_zero_point_tensor, - out_scale_tensor, - out_zero_point_tensor, + in_scale, + in_zero_point, + out_scale, + out_zero_point, out_dtype, ), ) diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index 3c6c518f16a..5a7f6e936fb 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -4,20 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import collections import itertools import logging -import math -import typing -from functools import partial -from typing import Iterable, List, Optional, Set, Tuple +from typing import Callable, Iterable, List, Optional, Set, Tuple, TypeAlias import torch -from executorch.backends.cadence.aot.memory_constraints import ( - GenerateMemConstraints, - MemConstraints, +from executorch.backends.cadence.aot.memory_constraints import MemConstraints +from executorch.backends.cadence.aot.memory_planning_algo import ( + get_aligned_offset, + MemoryPlanningAlgo, + MemoryPlanningState, ) from executorch.backends.cadence.aot.utils import MemoryConfig @@ -30,26 +29,6 @@ from torch.fx.passes.infra.pass_base import PassResult -# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id -def get_num_memories(memory_config: MemoryConfig) -> int: - return len(memory_config.memory_sizes) + 1 - - -# memory_space module provides num_memories indexed 0..num_memories-1. -def get_size(memory_config: MemoryConfig, exir_id: int) -> int: - return memory_config.memory_sizes[exir_id - 1] - - -def get_alignment(memory_config: MemoryConfig, exir_id: int) -> int: - # EXIR's spec.mem_id is indexed from 1..N. - assert memory_config.memory_alignments is not None - return memory_config.memory_alignments[exir_id - 1] - - -def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int: - return int(math.ceil(pre_aligned_offset / alignment) * alignment) - - def collect_specs_from_graph_module( graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, @@ -69,198 +48,127 @@ def collect_specs_from_graph_module( ) -# baseline tensor placement algorithm, that greedily tries to place the tensor in -# the fastest memory available -# flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13) -def position_based_greedy_with_hierarchy( - alignment: int, - specs: Set[TensorSpec], - graph_module: torch.fx.GraphModule, - graph_signature: ExportGraphSignature, - extra_padding: int = 0, - *, - memory_config: MemoryConfig, - mem_constraints: MemConstraints, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, -) -> List[int]: - # We do not use the `alignment` parameter and instead use the per-memory alignment - # constraints from `memory_config`. - del alignment - - num_memories = get_num_memories(memory_config) - bufsizes = [0] * num_memories - allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)] +class PositionBasedGreedyWithHierarchy(MemoryPlanningAlgo): + """Greedily place tensor in the fastest memory available.""" - # Generate the memory constraints - GenerateMemConstraints(mem_constraints, additional_constraint_gen_passes)( - graph_module - ) - - def overlap(spec: TensorSpec) -> Optional[TensorSpec]: - for allocated_spec in allocated_buffers[spec.mem_id]: - if Verifier.lifetime_overlap( - spec, allocated_spec - ) and Verifier.storage_overlap(spec, allocated_spec): - return allocated_spec - return None - - def memory_available(spec: TensorSpec) -> bool: - return get_aligned_offset( - spec.mem_offset + spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ) <= get_size(memory_config, spec.mem_id) - - # Iterate over all the specs in sorted order - for spec in sorted( - specs, - key=lambda spec: spec.allocated_memory, - reverse=True, - ): - # Skip allocation memory to any tensor whose spec id is in skip list. - if mem_constraints.skipped_spec(spec): - continue - - for spec.mem_id in range(1, num_memories): - if mem_constraints.is_mem_id_in_blocklist(spec, spec.mem_id): - continue + def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + """ + Greedily place the spec in the first memory that can fit it. + """ + for spec.mem_id in range(1, self.get_num_memories()): spec.mem_offset = 0 - while memory_available(spec) and (overlapped := overlap(spec)): + while self.is_valid_placement(spec) and ( + overlapped := state.get_overlapping_spec(spec) + ): + # Found an overlapping spec, so we need to adjust the offset = end of the overlapping spec + alignment. spec.mem_offset = get_aligned_offset( overlapped.mem_offset + overlapped.allocated_memory, - get_alignment(memory_config, spec.mem_id), + self.get_alignment(spec.mem_id), ) - if memory_available(spec): - allocated_buffers[spec.mem_id].append(spec) - bufsizes[spec.mem_id] = max( - spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id] - ) - break - if ( - not allocated_buffers[spec.mem_id] - or allocated_buffers[spec.mem_id][-1] is not spec - ): - raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - - # And now honor the various memory location constraints (i.e., infer the memory - # location of tensors in skip_specs from the constraints) for this spec. - if mem_constraints.relative_loc_constraints_exist(): - mem_constraints.resolve_relative_loc_constraints(spec) - - # At the end, all the keys in relative_loc_constraints should have been visited - # and emptied. - assert not mem_constraints.relative_loc_constraints_exist() - - logging.debug( - f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}" - ) - return bufsizes + if self.is_valid_placement(spec): + # Found a valid `spec.mem_offset` which is both valid and has no overlap. + state.place_spec(spec) + break -# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf -def greedy_by_size_for_offset_calculation_with_hierarchy( - alignment: int, - specs: Set[TensorSpec], - graph_module: torch.fx.GraphModule, - graph_signature: ExportGraphSignature, - extra_padding: int = 0, - *, - memory_config: MemoryConfig, - mem_constraints: MemConstraints, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, -) -> List[int]: - # We do not use the `alignment` parameter and instead use the per-memory alignment - # constraints from `memory_config`. - del alignment + def plan( + self, + specs: Set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + state = prev_state or MemoryPlanningState(self.memory_config) + + # Iterate over all the specs in sorted order + for spec in sorted( + specs, + key=lambda spec: spec.allocated_memory, + reverse=True, + ): + self.plan_spec(spec, state) + if not state.is_placed(spec): + raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - num_memories = get_num_memories(memory_config) - bufsizes = [0] * num_memories - allocated_buffers = [[] for _ in range(num_memories)] + return state - # Generate the memory constraints - GenerateMemConstraints(mem_constraints, additional_constraint_gen_passes)( - graph_module - ) - # Iterate over all the specs in sorted order - for spec in sorted( - specs, - key=lambda spec: spec.allocated_memory, - reverse=True, - ): - # Skip allocation memory to any tensor whose spec id is in skip list. - if mem_constraints.skipped_spec(spec): - continue +class GreedyWithHeuristic(MemoryPlanningAlgo): + """Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf.""" - for spec.mem_id in range(1, num_memories): - if mem_constraints.is_mem_id_in_blocklist(spec, spec.mem_id): - continue + def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + """ + Greedily place the spec in the first memory that can fit it. + """ + for spec.mem_id in range(1, self.get_num_memories()): prev_offset, smallest_gap = 0, float("inf") - for allocated_spec in allocated_buffers[spec.mem_id]: - if Verifier.lifetime_overlap(spec, allocated_spec): - if ( - gap := allocated_spec.mem_offset - prev_offset - ) >= spec.allocated_memory and gap < smallest_gap: - smallest_gap = gap - spec.mem_offset = prev_offset - # Note that different from the paper, which updates prev_offset for all - # allocated tensors, we only update tensors with overlapping lifetime. - # Updating prev_offset outside the if statement will include tensors without - # overlapping lifetime, causing unnecessary waste of memory and make the - # calculation of gap incorrect. Moving it out will make the algorithm degenerate - # to the naive one, reusing 0 tensor. The paper may have a typo here. - prev_offset = max( - get_aligned_offset( - allocated_spec.mem_offset + allocated_spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ), - prev_offset, - ) + for allocated_spec in state.allocated_buffers[spec.mem_id]: + if not Verifier.lifetime_overlap(spec, allocated_spec): + continue + + if ( + gap := allocated_spec.mem_offset - prev_offset + ) >= spec.allocated_memory and gap < smallest_gap: + smallest_gap = gap + spec.mem_offset = prev_offset + # Note that different from the paper, which updates prev_offset for all + # allocated tensors, we only update tensors with overlapping lifetime. + # Updating prev_offset outside the if statement will include tensors without + # overlapping lifetime, causing unnecessary waste of memory and make the + # calculation of gap incorrect. Moving it out will make the algorithm degenerate + # to the naive one, reusing 0 tensor. The paper may have a typo here. + prev_offset = max( + get_aligned_offset( + allocated_spec.mem_offset + allocated_spec.allocated_memory, + self.get_alignment(spec.mem_id), + ), + prev_offset, + ) if spec.mem_offset is None: if get_aligned_offset( prev_offset + spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ) > get_size(memory_config, spec.mem_id): + self.get_alignment(spec.mem_id), + ) > self.get_size(spec.mem_id): continue else: spec.mem_offset = prev_offset - bufsizes[spec.mem_id] = max( - spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id] - ) - allocated_buffers[spec.mem_id].append(spec) - allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset) + + state.place_spec(spec) # A data structure used for maintaining the tensor order # by offset, named ordered_allocated_ids in the paper + state.allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset) break - if spec not in allocated_buffers[spec.mem_id]: - raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - # And now honor the various memory location constraints (i.e., infer the memory - # location of tensors in skip_specs from the constraints) for this spec. - if mem_constraints.relative_loc_constraints_exist(): - mem_constraints.resolve_relative_loc_constraints(spec) + def plan( + self, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + """Plan memory allocation for the given tensor specs.""" + # We do not use the `alignment` parameter and instead use the per-memory alignment + # constraints from `memory_config`. + + state = prev_state or MemoryPlanningState(self.memory_config) + + # Iterate over all the specs in sorted order + for spec in sorted( + specs, + key=lambda spec: spec.allocated_memory, + reverse=True, + ): + self.plan_spec(spec, state) + if not state.is_placed(spec): + raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - # At the end, all the keys in relative_loc_constraints should have been visited - # and emptied. - assert not mem_constraints.relative_loc_constraints_exist() + logging.debug( + f"greedy by size for offset calculation with hierarchy returns bufsizes: {state.bufsizes}" + ) - logging.debug( - f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}" - ) - return bufsizes + return state def find_peak_memory_usages_per_memory( @@ -436,6 +344,12 @@ def print_memory_planning_info( ) +ConstraintGenPassType: TypeAlias = Callable[ + [MemConstraints], + Callable[[torch.fx.GraphModule], Optional[PassResult]], +] + + class CadenceMemoryPlanning: def __init__( self, @@ -444,28 +358,48 @@ def __init__( mem_algo: int, alloc_graph_input: bool = True, alloc_graph_output: bool = True, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, + additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]] = None, ) -> None: - self._init_mem_algos() - self.memory_config = memory_config self.opt_level = opt_level - self.mem_algo = mem_algo self.alloc_graph_input = alloc_graph_input self.alloc_graph_output = alloc_graph_output - self.additional_constraint_gen_passes = additional_constraint_gen_passes - def _init_mem_algos(self) -> None: - self.available_mem_algos = [ - position_based_greedy_with_hierarchy, - greedy_by_size_for_offset_calculation_with_hierarchy, + self.algo: MemoryPlanningAlgo = self.get_mem_algos( + memory_config, + opt_level, + alloc_graph_input, + alloc_graph_output, + additional_constraint_gen_passes, + )[mem_algo] + + @staticmethod + def get_mem_algos( + memory_config: MemoryConfig, + opt_level: int, + alloc_graph_input: bool, + alloc_graph_output: bool, + additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]], + ) -> list[MemoryPlanningAlgo]: + return [ + PositionBasedGreedyWithHierarchy( + memory_config, + MemConstraints( + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + ), + additional_constraint_gen_passes, + ), + GreedyWithHeuristic( + memory_config, + MemConstraints( + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + ), + additional_constraint_gen_passes, + ), ] def __call__( @@ -479,22 +413,11 @@ def run( graph_module: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] = None, ) -> PassResult: - mem_constraints = MemConstraints( - opt_level=self.opt_level, - alloc_graph_input=self.alloc_graph_input, - alloc_graph_output=self.alloc_graph_output, - ) - algo = partial( - self.available_mem_algos[self.mem_algo], - memory_config=self.memory_config, - mem_constraints=mem_constraints, - additional_constraint_gen_passes=self.additional_constraint_gen_passes, - ) # Create the memory planning pass. We allocate memory for input # (output) tensors if alloc_graph_input (alloc_graph_output) is # True. mem_planning = MemoryPlanningPass( - algo, + self.algo, allow_lifetime_and_storage_overlap=(self.opt_level >= 2), alloc_graph_input=self.alloc_graph_input, alloc_graph_output=self.alloc_graph_output, diff --git a/backends/cadence/aot/memory_planning_algo.py b/backends/cadence/aot/memory_planning_algo.py new file mode 100644 index 00000000000..5b67cc6c5fd --- /dev/null +++ b/backends/cadence/aot/memory_planning_algo.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# pyre-strict + +import logging +import math +from abc import ABC, abstractmethod +from typing import Callable, Optional + +import torch +from executorch.backends.cadence.aot.memory_constraints import ( + GenerateMemConstraints, + MemConstraints, +) +from executorch.backends.cadence.aot.utils import MemoryConfig +from executorch.exir.memory_planning import Verifier +from executorch.exir.pass_base import PassResult +from executorch.exir.tensor import TensorSpec +from torch.export.exported_program import ExportGraphSignature + + +def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int: + return int(math.ceil(pre_aligned_offset / alignment) * alignment) + + +class MemoryPlanningState: + def __init__(self, memory_config: MemoryConfig) -> None: + self.num_memories: int = len(memory_config.memory_sizes) + 1 + alignment = memory_config.memory_alignments + assert alignment is not None + assert len(alignment) == self.num_memories - 1 + self.alignment: list[int] = [1] + alignment + # TODO: Maybe keep this sorted with heapq? + self.allocated_buffers: list[list[TensorSpec]] = [ + [] for _ in range(self.num_memories) + ] + self.bufsizes: list[int] = [0] * self.num_memories + + def place_spec(self, spec: TensorSpec) -> None: + """Place the spec at the given memory and offset.""" + assert self.get_overlapping_spec(spec) is None + self.allocated_buffers[spec.mem_id].append(spec) + self.bufsizes[spec.mem_id] = max( + self.bufsizes[spec.mem_id], + get_aligned_offset( + spec.mem_offset + spec.allocated_memory, self.alignment[spec.mem_id] + ), + ) + + def get_overlapping_spec(self, spec: TensorSpec) -> Optional[TensorSpec]: + """Get the overlapping spec for the given spec.""" + for allocated_spec in self.allocated_buffers[spec.mem_id]: + if Verifier.lifetime_overlap( + spec, allocated_spec + ) and Verifier.storage_overlap(spec, allocated_spec): + return allocated_spec + return None + + def is_placed(self, spec: TensorSpec) -> bool: + """Check if the spec is placed.""" + return spec in self.allocated_buffers[spec.mem_id] + + +class MemoryPlanningAlgo(ABC): + """Callable memory planning algorithm interface.""" + + def __init__( + self, + memory_config: MemoryConfig, + placement_constraints: MemConstraints, + additional_constraint_gen_passes: Optional[ + list[ + Callable[ + [MemConstraints], + Callable[[torch.fx.GraphModule], Optional[PassResult]], + ] + ] + ] = None, + ) -> None: + self.memory_config = memory_config + self.placement_constraints = placement_constraints + self.additional_constraint_gen_passes = additional_constraint_gen_passes + + def get_num_memories(self) -> int: + """Get num memories indexed from 1..N, compatible with EXIR's spec.mem_id.""" + return len(self.memory_config.memory_sizes) + 1 + + def get_size(self, exir_id: int) -> int: + # memory_space module provides num_memories indexed 0..num_memories-1. + return self.memory_config.memory_sizes[exir_id - 1] + + def get_alignment(self, exir_id: int) -> int: + # EXIR's spec.mem_id is indexed from 1..N. + assert self.memory_config.memory_alignments is not None + return self.memory_config.memory_alignments[exir_id - 1] + + def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None: + """Populate the constraints for the memory planning algorithm.""" + GenerateMemConstraints( + mem_constraints=self.placement_constraints, + additional_constraint_gen_passes=self.additional_constraint_gen_passes, + )(graph_module) + + def is_valid_placement(self, spec: TensorSpec) -> bool: + return get_aligned_offset( + spec.mem_offset + spec.allocated_memory, + self.get_alignment(spec.mem_id), + ) <= self.get_size(spec.mem_id) + + @abstractmethod + def plan( + self, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + """Plan memory allocation for the given tensor specs.""" + pass + + def __call__( + self, + alignment: int, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + ) -> list[int]: + """Callable interface for ET memory planning.""" + self.populate_constraints(graph_module) + + # First plan the memory allocation for specs without relative constraints. + specs_without_relative_constraints = set( + filter( + lambda spec: not self.placement_constraints.skipped_spec(spec) + and not self.placement_constraints.is_mem_id_in_blocklist( + spec, spec.mem_id + ), + specs, + ) + ) + + # Call memory planning to get bufsizes. + state = self.plan( + specs_without_relative_constraints, + graph_module, + graph_signature, + extra_padding, + ) + + for spec in specs_without_relative_constraints: + # And now honor the various memory location constraints (i.e., infer the memory + # location of tensors in skip_specs from the constraints) for this spec. + self.placement_constraints.resolve_relative_loc_constraints(spec) + + # At the end, all the keys in relative_loc_constraints should have been visited + # and emptied. + assert not self.placement_constraints.relative_loc_constraints_exist() + + logging.debug(f"Memory planning algo found bufsizes: {state.bufsizes}") + return state.bufsizes diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 996dfa43f8f..fe23ea73754 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -447,7 +447,7 @@ def call_operator( kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if op != exir_ops.edge.cadence.requantize.default: + if op != exir_ops.edge.cadence.requantize.per_tensor: return super().call_operator(op, args, kwargs, meta) # Parse the args diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index d78bdfeba6e..d85a0cc9be4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2300,6 +2300,52 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return result +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass): + """ + Extracts a single value argument of mul op to a separate full op. + """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for mul_node in graph_module.graph.find_nodes( + op="call_function", target=torch.ops.aten.mul.Tensor + ): + x_arg, const_arg = mul_node.args + + # Swap arguments if the order is wrong + if isinstance(const_arg, torch.fx.Node): + x_arg, const_arg = const_arg, x_arg + + # Skip if the const_arg is not a scalar + if not isinstance(const_arg, (float, int)) or not isinstance( + x_arg, torch.fx.Node + ): + continue + + # Cast the const_arg to the dtype of the x_arg + full_arg = self.resolve_full_arg(x_arg, const_arg) + + # Extract an argument to a separate full op. + with graph_module.graph.inserting_before(mul_node): + full_tensor = graph_module.graph.call_function( + exir_ops.edge.aten.full.default, args=([1], full_arg) + ) + new_mul_node = graph_module.graph.call_function( + torch.ops.aten.mul.Tensor, args=(x_arg, full_tensor) + ) + # Replace the old mul with a newly created mul. + mul_node.replace_all_uses_with(new_mul_node) + graph_module.graph.erase_node(mul_node) + return super().call(graph_module) + + def resolve_full_arg(self, x_arg, const_arg): + if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int): + const_arg = float(const_arg) + if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float): + const_arg = int(const_arg) + return const_arg + + # This class encapsulates all the functions that replace/switch one op in the # graph with another. class CadenceReplaceOpsInGraph: diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 30ea91bafb5..ead8b46f775 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest -from typing import Final, List, Tuple +from typing import cast, Final, List, Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -26,10 +26,10 @@ ) from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ProxyValue -from parameterized import parameterized +from executorch.exir.pass_base import PassResult, ProxyValue from torch import nn @@ -43,7 +43,7 @@ def check_op_counts( class TestFusionPasses(TestFusionPassesBase): - def test_fuse_mm_with_add(self): + def test_fuse_mm_with_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -55,7 +55,9 @@ def test_fuse_mm_with_add(self): output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module + converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -63,7 +65,7 @@ def test_fuse_mm_with_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_fuse_view_mm_view_add(self): + def test_fuse_view_mm_view_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) @@ -83,7 +85,9 @@ def test_fuse_view_mm_view_add(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -91,7 +95,7 @@ def test_fuse_view_mm_view_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_keep_view_mm_view_add(self): + def test_keep_view_mm_view_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(4, 8, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(2, 4, 6, dtype=torch.float32)) @@ -112,7 +116,8 @@ def test_keep_view_mm_view_add(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since z cannot be # broadcasted to the out of mm. @@ -122,7 +127,7 @@ def test_keep_view_mm_view_add(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 1) - def test_fuse_mm_add_with_bias(self): + def test_fuse_mm_add_with_bias(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -136,7 +141,8 @@ def test_fuse_mm_add_with_bias(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 @@ -144,7 +150,7 @@ def test_fuse_mm_add_with_bias(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 0) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 0) - def test_keep_mm_add_with_multiple_users(self): + def test_keep_mm_add_with_multiple_users(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) @@ -161,7 +167,8 @@ def test_keep_mm_add_with_multiple_users(self): ) builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseMMWithAdd()(original_graph).graph_module + p = FuseMMWithAdd() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that mm and add were not fused to addmm, since add has multiple # users. @@ -171,17 +178,19 @@ def test_keep_mm_add_with_multiple_users(self): self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1) self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3) - # TODO(matthiascremon): enable that pass with new flow + # TODO(matthiascremon) -> None: enable that pass with new flow @torch.no_grad() @unittest.expectedFailure - def test_legacy_conv_bn_fusion(self): + def test_legacy_conv_bn_fusion(self) -> None: class ModelConvBN(torch.nn.Module): - def __init__(self, in_features: int, out_features: int, kernel_size: int): + def __init__( + self, in_features: int, out_features: int, kernel_size: int + ) -> None: super().__init__() self.conv1d = nn.Conv1d(in_features, out_features, kernel_size) self.bn = nn.BatchNorm1d(out_features) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.conv1d(x) return self.bn(y) @@ -189,8 +198,7 @@ def forward(self, x): x = torch.randn(1, 64, 4) graph_module = ( - compiler.export_to_executorch(model.eval(), (x,)) - .exported_program() + compiler.export_to_executorch_gen_etrecord(model.eval(), (x,)) .exported_program() .graph_module ) @@ -207,7 +215,7 @@ def forward(self, x): 0, ) - def test_permute_transpose_fusion(self): + def test_permute_transpose_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32)) permute = builder.call_operator( @@ -217,11 +225,10 @@ def test_permute_transpose_fusion(self): op=exir_ops.edge.aten.transpose_copy.int, args=(permute, 1, 0), ) - builder.output(output) + builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedTransposeOrPermuteOps()( - original_graph - ).graph_module + p = FuseCascadedTransposeOrPermuteOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that permute op was fused with transpose op self.assertEqual( @@ -231,7 +238,7 @@ def test_permute_transpose_fusion(self): count_node(converted_graph, exir_ops.edge.aten.transpose_copy.int), 0 ) - def test_view_fusion(self): + def test_view_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) view1 = builder.call_operator( @@ -243,16 +250,17 @@ def test_view_fusion(self): output = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(view2, [1, 12, 10]) ) - builder.output(output) + builder.output([output]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedViewOps()(original_graph).graph_module + p = FuseCascadedViewOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert that only one view op remains self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1 ) - def test_view_fusion_branched(self): + def test_view_fusion_branched(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) y = builder.call_operator( @@ -266,14 +274,15 @@ def test_view_fusion_branched(self): ) builder.output([z, t]) original_graph = builder.get_graph_module() - converted_graph = FuseCascadedViewOps()(original_graph).graph_module + p = FuseCascadedViewOps() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # z and t should be fused and y should be eliminated. self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2 ) - def test_force_quant_dequant_fusion(self): + def test_force_quant_dequant_fusion(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -287,22 +296,21 @@ def test_force_quant_dequant_fusion(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass( - force_quant_dequant_fusion=True - )(original_graph).graph_module + p = FuseQuantDequantToRequantizePass(force_quant_dequant_fusion=True) + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ # Verify that dequant/quant pair was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) - def test_no_replace_quant_permute_dequant_with_requantize(self): + def test_no_replace_quant_permute_dequant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -316,11 +324,11 @@ def test_no_replace_quant_permute_dequant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass( - force_quant_dequant_fusion=False - )(original_graph).graph_module + + p = FuseQuantDequantToRequantizePass(force_quant_dequant_fusion=False) + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ @@ -328,11 +336,11 @@ def test_no_replace_quant_permute_dequant_with_requantize(self): # quantize -> permute -> dequantize should not be replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, - exir_ops.edge.cadence.requantize.default: 0, + exir_ops.edge.cadence.requantize.per_tensor: 0, }, ) - def test_replace_quant_view_dequant_with_requantize(self): + def test_replace_quant_view_dequant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) quant = builder.call_operator( @@ -346,22 +354,21 @@ def test_replace_quant_view_dequant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(view, 4.5, 6, 0, 127, torch.int8), ) - builder.output(dequant) + builder.output([dequant]) original_graph = builder.get_graph_module() - converted_graph = FuseQuantDequantToRequantizePass()( - original_graph - ).graph_module + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( converted_graph, expected_op_counts={ # Verify that dequant/quant pair was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) - def test_replace_dequant_quant_with_requantize(self): + def test_replace_dequant_quant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) dequant = builder.call_operator( @@ -372,22 +379,22 @@ def test_replace_dequant_quant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(dequant, 4.5, 6, 0, 127, torch.int8), ) - builder.output(quant) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that dequant -> quant was replaced with requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) - def test_replace_dequant_permute_quant_with_requantize(self): + def test_replace_dequant_permute_quant_with_requantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) dequant = builder.call_operator( @@ -401,49 +408,49 @@ def test_replace_dequant_permute_quant_with_requantize(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(permute, 4.5, 6, 0, 127, torch.int8), ) - builder.output(quant) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that dequant -> permute -> quant was replaced with permute -> requantize. exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, exir_ops.edge.aten.permute_copy.default: 1, - exir_ops.edge.cadence.requantize.default: 1, + exir_ops.edge.cadence.requantize.per_tensor: 1, }, ) - def test_remove_nop_dequant_quant(self): - LEADING_DIMS: Final[int] = 12 - IN_DIM: Final[int] = 6 - OUT_DIM: Final[int] = 12 + def test_remove_nop_dequant_quant(self) -> None: + leading_dims = 12 + in_dim = 6 + out_dim = 12 builder = GraphBuilder() x = builder.placeholder( - "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32) + "x", torch.randn(leading_dims, in_dim, dtype=torch.float32) ) quant1 = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(x, 4.5, 6, 0, 127, torch.int8), ) weights = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim, in_dim], 1) ) bias = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim], 1) ) weight_zero_point = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0) + op=exir_ops.edge.aten.full.default, args=([in_dim], 0) ) out_multiplier = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1) + op=exir_ops.edge.aten.full.default, args=([out_dim], 1) ) out_shift = builder.call_operator( - op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0) + op=exir_ops.edge.aten.full.default, args=([out_dim], 0) ) linear1 = builder.call_operator( op=exir_ops.edge.cadence.quantized_linear.default, @@ -488,12 +495,12 @@ def test_remove_nop_dequant_quant(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(linear2, 1.2, 3, 0, 127, torch.int8), ) - builder.output(dequant2) - graph_module = FuseQuantDequantToRequantizePass()( - builder.get_graph_module() - ).graph_module + builder.output([dequant2]) + original_graph = builder.get_graph_module() + p = FuseQuantDequantToRequantizePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ # Verify that one dequant/quant pair was removed from chain: # quant->linear->dequant->permute->quant->linear->dequant @@ -504,7 +511,7 @@ def test_remove_nop_dequant_quant(self): }, ) - def test_fuse_mul_into_dequant(self): + def test_fuse_mul_into_dequant(self) -> None: INPUT_SHAPE: Final[List[int]] = [4, 32] DEQUANT_SCALE: Final[float] = 1.5 FULL_VALUE: Final[float] = 3 @@ -523,14 +530,14 @@ def test_fuse_mul_into_dequant(self): op=exir_ops.edge.aten.mul.Tensor, args=(dequant, full), ) - builder.output(mul) - graph_module = FuseMulTensorIntoDequantPass()( - builder.get_graph_module() - ).graph_module + builder.output([mul]) + original_graph = builder.get_graph_module() + p = FuseMulTensorIntoDequantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, exir_ops.edge.aten.full.default: 0, @@ -539,7 +546,8 @@ def test_fuse_mul_into_dequant(self): ) # verify that the dequant scale value was updated correctly - for node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default @@ -547,7 +555,7 @@ def test_fuse_mul_into_dequant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE) - def test_fuse_mul_scalar_into_dequant(self): + def test_fuse_mul_scalar_into_dequant(self) -> None: dequant_scale = 0.006 mul_value = 0.3 @@ -565,14 +573,14 @@ def test_fuse_mul_scalar_into_dequant(self): op=exir_ops.edge.aten.mul.Scalar, args=(dequant, mul_value), ) - builder.output(mul_scalar) - graph_module = builder.get_graph_module() - - graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module + builder.output([mul_scalar]) + original_graph = builder.get_graph_module() + p = FuseMulScalarIntoDequantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, exir_ops.edge.aten.mul.Scalar: 0, @@ -580,7 +588,8 @@ def test_fuse_mul_scalar_into_dequant(self): ) # verify that the dequant scale value was updated correctly - for node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default @@ -588,7 +597,7 @@ def test_fuse_mul_scalar_into_dequant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, dequant_scale * mul_value) - def test_fuse_mul_into_quant(self): + def test_fuse_mul_into_quant(self) -> None: quant_scale = 1.5 mul_value = 10 @@ -606,14 +615,14 @@ def test_fuse_mul_into_quant(self): op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(mul, quant_scale, 0, 0, 255, torch.uint8), ) - builder.output(quant) - graph_module = FuseMulTensorIntoQuantPass()( - builder.get_graph_module() - ).graph_module + builder.output([quant]) + original_graph = builder.get_graph_module() + p = FuseMulTensorIntoQuantPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # verify that the mul and full ops were removed self.check_op_counts( - graph_module, + converted_graph, expected_op_counts={ exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, exir_ops.edge.aten.full.default: 0, @@ -622,7 +631,8 @@ def test_fuse_mul_into_quant(self): ) # verify that the quant scale value was updated correctly - for node in graph_module.graph.nodes: + deq_scale = -1 + for node in converted_graph.graph.nodes: if ( node.target == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default @@ -630,7 +640,7 @@ def test_fuse_mul_into_quant(self): deq_scale = node.args[1] self.assertEqual(deq_scale, quant_scale * mul_value) - def test_fuse_then_transpose_pass(self): + def test_fuse_then_transpose_pass(self) -> None: # Create a graph with full -> transpose. builder = GraphBuilder() full_node = builder.call_operator( @@ -648,10 +658,10 @@ def test_fuse_then_transpose_pass(self): op=exir_ops.edge.aten.view_copy.default, args=(permute_node, (1, 6, 1)), ) - builder.output(view_node) - gm = builder.get_graph_module() + builder.output([view_node]) + original_graph = builder.get_graph_module() self.check_op_counts( - gm, + original_graph, expected_op_counts={ exir_ops.edge.aten.full.default: 1, exir_ops.edge.aten.transpose_copy.int: 1, @@ -661,7 +671,8 @@ def test_fuse_then_transpose_pass(self): ) # Check that the pass fuses the full with all other ops (transpose, permute, view). - gm_after_pass = FuseFullThenReshapePass()(gm).graph_module + p = FuseFullThenReshapePass() + gm_after_pass = cast(PassResult, p(original_graph)).graph_module self.check_op_counts( gm_after_pass, expected_op_counts={ @@ -708,7 +719,7 @@ def _create_operator( else: raise ValueError(f"Unsupported op: {op}") - @parameterized.expand( + @expand( [ # transpose -> quant -> same transpose => fuse ( @@ -858,7 +869,7 @@ def test_fuse_transpose_permute_pairs( quant_op: torch._ops.OpOverload, expected_is_fused: bool, dims: Tuple[int, int, int] = (2, 3, 4), - ): + ) -> None: # Create a graph with transpose/permute -> quant -> transpose/permute. builder = GraphBuilder() x = builder.placeholder("x", torch.randn(dims)) @@ -911,7 +922,7 @@ def test_fuse_transpose_permute_pairs( expected_op_counts=expected_op_counts, ) - def test_fusion_for_forked_transposes(self): + def test_fusion_for_forked_transposes(self) -> None: # Create a graph with # transpose -> quant -> transpose. # -> quant -> transpose. @@ -946,7 +957,8 @@ def test_fusion_for_forked_transposes(self): ) # Fuse all the transpose ops. - gm_after_pass = FuseTransposeOrPermuteOpPairsPass()(gm).graph_module + p = FuseTransposeOrPermuteOpPairsPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module self.check_op_counts( gm_after_pass, expected_op_counts={ diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index b7616b047d3..73b0cba65ce 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import math import unittest -from typing import cast, Optional +from typing import cast, List, Optional import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -19,6 +19,7 @@ find_peak_memory_usage, ) from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.backends.cadence.aot.utils import ( get_default_memory_config, MemoryConfig, @@ -27,7 +28,6 @@ from executorch.exir.memory_planning import collect_specs_from_nodes from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.tests.models import MultiLayerPerceptron -from parameterized.parameterized import parameterized from torch.fx import GraphModule @@ -224,11 +224,11 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None: # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes. def run_memory_planning( self, - original, - opt_level=2, - mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy - alloc_graph_input=True, - alloc_graph_output=True, + original: GraphModule, + opt_level: int = 2, + mem_algo: int = 1, # greedy_by_size_for_offset_calculation_with_hierarchy + alloc_graph_input: bool = True, + alloc_graph_output: bool = True, memory_config: Optional[MemoryConfig] = None, ) -> GraphModule: if memory_config is None: @@ -242,7 +242,7 @@ def run_memory_planning( alloc_graph_output=alloc_graph_output, )(graph_module).graph_module - @parameterized.expand( + @expand( [ [ [3, 6], # x_shape @@ -259,7 +259,11 @@ def run_memory_planning( ] ) def test_optimize_cat_on_placeholders( - self, x_shape, y_shape, concat_dim, alloc_graph_input + self, + x_shape: List[int], + y_shape: List[int], + concat_dim: int, + alloc_graph_input: bool, ) -> None: concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]] builder = GraphBuilder() @@ -294,7 +298,12 @@ def test_optimize_cat_on_placeholders( # "add_add_cat_model" : cat(x + 123, y + 456) # "add_add_cat_add_model": cat(x + 123, y + 456) + 789 def get_graph_module( - self, model_name, x_shape, y_shape, concated_shape, concat_dim + self, + model_name: str, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> GraphModule: builder = GraphBuilder() x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32)) @@ -346,7 +355,7 @@ def get_graph_module( raise ValueError(f"Unknown model name {model_name}") - @parameterized.expand( + @expand( [ ( "outermost", @@ -363,10 +372,14 @@ def get_graph_module( 1, # concat dim ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_optimized( - self, _, x_shape, y_shape, concated_shape, concat_dim + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> None: original = self.get_graph_module( "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim @@ -379,7 +392,7 @@ def test_cat_optimized( self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ ( "non_outermost", @@ -389,10 +402,14 @@ def test_cat_optimized( 1, # concat dim ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_not_optimized( - self, _, x_shape, y_shape, concated_shape, concat_dim + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, ) -> None: original = self.get_graph_module( "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim @@ -404,7 +421,7 @@ def test_cat_not_optimized( self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ ( "aligned", @@ -423,10 +440,15 @@ def test_cat_not_optimized( 1, # expected cat nodes ), ], - name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", ) def test_cat_not_graph_output( - self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes + self, + _, + x_shape: List[int], + y_shape: List[int], + concated_shape: List[int], + concat_dim: int, + expected_cat_nodes: int, ) -> None: original = self.get_graph_module( "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim @@ -493,13 +515,13 @@ def test_optimize_cat_with_slice(self) -> None: self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1) self.verify_nop_memory_alloc(graph_module) - @parameterized.expand( + @expand( [ (True,), # alloc_graph_input (False,), # alloc_graph_input ], ) - def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None: + def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input: bool) -> None: x_shape = [5, 6] y_shape = [3, 6] concated_shape = [8, 6] diff --git a/backends/cadence/aot/tests/test_pass_filter.py b/backends/cadence/aot/tests/test_pass_filter.py index 21b004d4942..9bfd71556bd 100644 --- a/backends/cadence/aot/tests/test_pass_filter.py +++ b/backends/cadence/aot/tests/test_pass_filter.py @@ -4,13 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest - from copy import deepcopy +from typing import Callable, Dict + from executorch.backends.cadence.aot import pass_utils from executorch.backends.cadence.aot.pass_utils import ( ALL_CADENCE_PASSES, @@ -23,24 +24,26 @@ class TestBase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # Before running each test, create a copy of _all_passes to later restore it after test. # This avoids messing up the original _all_passes when running tests. self._all_passes_original = deepcopy(ALL_CADENCE_PASSES) # Clear _all_passes to do a clean test. It'll be restored after each test in tearDown(). pass_utils.ALL_CADENCE_PASSES.clear() - def tearDown(self): + def tearDown(self) -> None: # Restore _all_passes to original state before test. pass_utils.ALL_CADENCE_PASSES = self._all_passes_original - def get_filtered_passes(self, filter_): + def get_filtered_passes( + self, filter_: Callable[[ExportPass], bool] + ) -> Dict[ExportPass, CadencePassAttribute]: return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)} # Test pass registration class TestPassRegistration(TestBase): - def test_register_cadence_pass(self): + def test_register_cadence_pass(self) -> None: pass_attr_O0 = CadencePassAttribute(opt_level=0) pass_attr_debug = CadencePassAttribute(opt_level=None, debug_pass=True) pass_attr_O1_all_backends = CadencePassAttribute( @@ -73,7 +76,7 @@ class DummyPass_Debug(ExportPass): # Test pass filtering class TestPassFiltering(TestBase): - def test_filter_none(self): + def test_filter_none(self) -> None: pass_attr_O0 = CadencePassAttribute(opt_level=0) pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) pass_attr_O1_all_backends = CadencePassAttribute( @@ -103,7 +106,7 @@ class DummyPass_O1_All_Backends(ExportPass): } self.assertEqual(O1_filter_passes, expected_passes) - def test_filter_debug(self): + def test_filter_debug(self) -> None: pass_attr_O1_debug = CadencePassAttribute(opt_level=1, debug_pass=True) pass_attr_O2 = CadencePassAttribute(opt_level=2) @@ -122,7 +125,7 @@ class DummyPass_O2(ExportPass): # chooses debug=False. self.assertEqual(debug_filter_passes, {DummyPass_O2: pass_attr_O2}) - def test_filter_all(self): + def test_filter_all(self) -> None: @register_cadence_pass(CadencePassAttribute(opt_level=1)) class DummyPass_O1(ExportPass): pass @@ -138,7 +141,7 @@ class DummyPass_O2(ExportPass): # passes with opt_level <= 0 self.assertEqual(debug_filter_passes, {}) - def test_filter_opt_level_None(self): + def test_filter_opt_level_None(self) -> None: pass_attr_O1 = CadencePassAttribute(opt_level=1) pass_attr_O2_debug = CadencePassAttribute(opt_level=2, debug_pass=True) diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 012f109f313..5fe2848be94 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest -from typing import cast, Tuple +from typing import cast, List, Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -34,21 +34,22 @@ RemoveZeroSizedCatArgsPass, RemoveZeroSizedConstantPadNd, ) +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops -from parameterized.parameterized import parameterized from pyre_extensions import none_throws from torch.fx.passes.infra.pass_base import PassResult class TestRemoveOpsPasses(unittest.TestCase): - @parameterized.expand( + + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_to_ops(self, shape: Tuple[int]): + def test_remove_to_ops(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) x = builder.call_operator( @@ -69,7 +70,7 @@ def test_remove_to_ops(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(7, 6, 5)], [(7, 6)], @@ -77,7 +78,7 @@ def test_remove_to_ops(self, shape: Tuple[int]): ] ) @torch.no_grad() - def test_remove_nop_add_op_pass(self, shape: Tuple[int]): + def test_remove_nop_add_op_pass(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) zeros = builder.call_operator( @@ -101,7 +102,7 @@ def test_remove_nop_add_op_pass(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(7, 6, 5)], [(7, 6)], @@ -109,7 +110,7 @@ def test_remove_nop_add_op_pass(self, shape: Tuple[int]): ] ) @torch.no_grad() - def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): + def test_remove_nop_mul_op_pass(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) zeros = builder.call_operator( @@ -133,13 +134,13 @@ def test_remove_nop_mul_op_pass(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_alias_copy(self, shape: Tuple[int]): + def test_remove_alias_copy(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) alias = builder.call_operator( @@ -155,13 +156,13 @@ def test_remove_alias_copy(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3)], ] ) @torch.no_grad() - def test_remove_detach_copy(self, shape: Tuple[int]): + def test_remove_detach_copy(self, shape: Tuple[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) detach = builder.call_operator( @@ -177,7 +178,7 @@ def test_remove_detach_copy(self, shape: Tuple[int]): 0, ) - @parameterized.expand( + @expand( [ [(1, 2, 3), (0, 0)], ] @@ -185,7 +186,7 @@ def test_remove_detach_copy(self, shape: Tuple[int]): @torch.no_grad() def test_remove_zero_sized_constant_pad_nd( self, shape: Tuple[int], padding: Tuple[int] - ): + ) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) pad = builder.call_operator( @@ -201,7 +202,7 @@ def test_remove_zero_sized_constant_pad_nd( 0, ) - def test_remove_expand(self): + def test_remove_expand(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([2, 3, 5], dtype=torch.float32)) expand = builder.call_operator( @@ -216,7 +217,7 @@ def test_remove_expand(self): count_node(graph_after_passes, exir_ops.edge.aten.expand_copy.default), 0 ) - def test_remove_zero_arg_cat(self): + def test_remove_zero_arg_cat(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([1, 0, 3, 5], dtype=torch.float32)) y = builder.placeholder("y", torch.randn([2, 0, 3, 5], dtype=torch.float32)) @@ -232,18 +233,19 @@ def test_remove_zero_arg_cat(self): count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 0 ) - def test_remove_clone(self): + def test_remove_clone(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,)) builder.output([clone]) original = builder.get_graph_module() - graph_after_passes = RemoveCloneOpPass()(original).graph_module + p = RemoveCloneOpPass() + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.clone.default), 0 ) - def test_remove_contiguous(self): + def test_remove_contiguous(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32)) contiguous = builder.call_operator( @@ -251,19 +253,20 @@ def test_remove_contiguous(self): ) builder.output([contiguous]) original = builder.get_graph_module() - graph_after_passes = RemoveContiguousOpPass()(original).graph_module + p = RemoveContiguousOpPass() + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0 ) - @parameterized.expand( + @expand( [ [(3, 5), [3, 5]], [(1,), [-1]], ] ) @torch.no_grad() - def test_remove_nop_view(self, shape, new_shape): + def test_remove_nop_view(self, shape: Tuple[int], new_shape: List[int]) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) view = builder.call_operator( @@ -278,7 +281,7 @@ def test_remove_nop_view(self, shape, new_shape): count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0 ) - def test_remove_nop_slice(self): + def test_remove_nop_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) slice_ = builder.call_operator( @@ -299,7 +302,7 @@ def test_remove_nop_slice(self): count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0 ) - def test_remove_nop_select_before_view(self): + def test_remove_nop_select_before_view(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) select = builder.call_operator( @@ -323,7 +326,7 @@ def test_remove_nop_select_before_view(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_select_before_add(self): + def test_remove_nop_select_before_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -345,7 +348,7 @@ def test_remove_nop_select_before_add(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_select_before_mul(self): + def test_remove_nop_select_before_mul(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -367,7 +370,7 @@ def test_remove_nop_select_before_mul(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_select_before_div(self): + def test_remove_nop_select_before_div(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -389,7 +392,7 @@ def test_remove_nop_select_before_div(self): count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) - def test_remove_nop_quant_dequant(self): + def test_remove_nop_quant_dequant(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 8)) q0 = builder.call_operator( @@ -441,7 +444,7 @@ def test_remove_nop_quant_dequant(self): 1, ) - def test_remove_nop_aten_linalg_vector_norm(self): + def test_remove_nop_aten_linalg_vector_norm(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 1, 128, dtype=torch.float32)) linalg_vector_norm = builder.call_operator( @@ -736,7 +739,7 @@ def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 4 ) - def test_remove_dequant_on_branch(self): + def test_remove_dequant_on_branch(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 8, 4, 6)) x = builder.call_operator(op=exir_ops.edge.aten.abs.default, args=(x,)) diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 3e64a0ecd7c..50f5ca32c47 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest +from typing import cast import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -29,10 +30,11 @@ SinkOpsCloserToUsePass, ) from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult class TestReorderPasses(unittest.TestCase): - def test_sink_dequantize(self): + def test_sink_dequantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32)) y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32)) @@ -103,9 +105,10 @@ def test_sink_dequantize(self): op=exir_ops.edge.aten.cat.default, args=([abs_1, dequantize_per_tensor_1],), ) - builder.output(cat) + builder.output([cat]) original_graph = builder.get_graph_module() - converted_graph = SinkOpsCloserToUsePass()(original_graph).graph_module + p = SinkOpsCloserToUsePass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # Expect the SinkDequant pass to move dequant(y) from above the relu to just below it self.assertTrue( @@ -123,7 +126,7 @@ def test_sink_dequantize(self): ), ) - def test_advance_branched_quantize(self): + def test_advance_branched_quantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(64, 3, dtype=torch.float32)) view = builder.call_operator( @@ -174,9 +177,8 @@ def test_advance_branched_quantize(self): ] ) original_graph = builder.get_graph_module() - graph_module = AdvanceQuantizeOpAboveDefInBranchPass()( - original_graph - ).graph_module + p = AdvanceQuantizeOpAboveDefInBranchPass() + graph_module = cast(PassResult, p(original_graph)).graph_module graph_module.graph.eliminate_dead_code() nodes = get_compute_nodes_in_gm(graph_module) # The quantize op should be hoisted to dominate the branch @@ -208,19 +210,20 @@ def test_advance_branched_quantize(self): ), 4, ) - graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module + p = FuseQuantDequantToRequantizePass() + graph_module = cast(PassResult, p(graph_module)).graph_module # We expect 3 dequant/quant pairs to be removed because they have matching params, # leaving a single dequant/quant pair that is then merged into a requantize op self.assertEqual( count_node( graph_module, - exir_ops.edge.cadence.requantize.default, + exir_ops.edge.cadence.requantize.per_tensor, ), 1, ) @torch.no_grad() - def test_advance_quantize(self): + def test_advance_quantize(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32)) weights = builder.placeholder( @@ -268,14 +271,13 @@ def test_advance_quantize(self): op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(quantized_linear, 0.01627226173877716, -7, -128, 127, torch.int8), ) - builder.output(dequantize_per_tensor) + builder.output([dequantize_per_tensor]) original_graph = builder.get_graph_module() - converted_graph = AdvanceQuantizeOpAboveDefInBranchPass()( - original_graph - ).graph_module - converted_graph = AdvanceQuantizeOpAboveDefChainPass()( - original_graph - ).graph_module + + p1 = AdvanceQuantizeOpAboveDefInBranchPass() + tmp_graph = cast(PassResult, p1(original_graph)).graph_module + p2 = AdvanceQuantizeOpAboveDefChainPass() + converted_graph = cast(PassResult, p2(tmp_graph)).graph_module # Assert that permute node is now the successor of the quant node. self.assertTrue( get_node_pos( @@ -284,7 +286,7 @@ def test_advance_quantize(self): < get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default) ) - def test_postpone_dequantize1(self): + def test_postpone_dequantize1(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32)) weights = builder.placeholder( @@ -332,11 +334,10 @@ def test_postpone_dequantize1(self): op=exir_ops.edge.aten.permute_copy.default, args=(dequantize_per_tensor, [1, 0, 3, 2]), ) - builder.output(permute) + builder.output([permute]) original_graph = builder.get_graph_module() - converted_graph = PostponeDequantizeOpBelowUseChainPass()( - original_graph - ).graph_module + p = PostponeDequantizeOpBelowUseChainPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module # Assert that dequant node is now the successor of the permute node. self.assertTrue( get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default) @@ -346,7 +347,7 @@ def test_postpone_dequantize1(self): ) ) - def test_postpone_dequantize_branched(self): + def test_postpone_dequantize_branched(self) -> None: builder = GraphBuilder() x = builder.placeholder( "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8) @@ -403,14 +404,13 @@ def test_postpone_dequantize_branched(self): ) builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2]) original_graph = builder.get_graph_module() - graph_module = PostponeDequantizeOpBelowUseChainPass()( - original_graph - ).graph_module - graph_module.graph.eliminate_dead_code() + p = PostponeDequantizeOpBelowUseChainPass() + converted_graph = cast(PassResult, p(original_graph)).graph_module + converted_graph.graph.eliminate_dead_code() # Asset that the dequant node was split into 4, one per branch self.assertEqual( count_node( - graph_module, + converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, ), 3, @@ -419,7 +419,7 @@ def test_postpone_dequantize_branched(self): # Assert that the dequant node is no longer the predecessor of the squeeze node self.assertTrue( nodes_not_connected_in_gm( - graph_module, + converted_graph, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.aten.squeeze_copy.dims, ), @@ -427,14 +427,14 @@ def test_postpone_dequantize_branched(self): # Assert that dequant node is not predecessor of slice (it should've been moved below slice) self.assertTrue( nodes_not_connected_in_gm( - graph_module, + converted_graph, exir_ops.edge.cadence.dequantize_per_tensor.default, exir_ops.edge.aten.slice_copy.Tensor, ), ) # 4d -> permute -> 4d -> view -> 3d - def test_permute3_view4_chains(self): + def test_permute3_view4_chains(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -453,14 +453,10 @@ def test_permute3_view4_chains(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [0, 1, 3, 2]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - # Performing transform - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order becomes view, view, permute, permute nodes = get_compute_nodes_in_gm(converted_graph) @@ -471,7 +467,7 @@ def test_permute3_view4_chains(self): self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy) # 3d -> permute -> 3d -> view -> 4d - def test_permute4_view3_chains(self): + def test_permute4_view3_chains(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -490,14 +486,11 @@ def test_permute4_view3_chains(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [2, 1, 0]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - # Performing transform - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order becomes view, view, permute, permute @@ -511,7 +504,7 @@ def test_permute4_view3_chains(self): # Negative test case where the transform should not happen. # permute->4d->view->3d where the view not only removes the dimension whose # size is 1 (this is ok), but also changes the size of the dimensions (not ok). - def test_permute_view_chains_neg(self): + def test_permute_view_chains_neg(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(3, 1, 768)) aten_view_copy_default = builder.call_operator( @@ -530,14 +523,12 @@ def test_permute_view_chains_neg(self): op=exir_ops.edge.aten.permute_copy.default, args=(aten_view_copy_default_1, [2, 1, 0]), ) - builder.output( - aten_permute_copy_default_1, - ) + builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() + # Performing transform (nothing should happen) - converted_graph = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()( - original_graph - ).graph_module + p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() + converted_graph = cast(PassResult, p(original_graph)).graph_module converted_graph.graph.eliminate_dead_code() # Assert the order is still view, permute, view, permute diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 4ff84a296e8..6d12c991d6d 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -8,14 +8,14 @@ import operator import unittest -from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union +from typing import cast, List, Optional, Sequence, Tuple, Union import torch from executorch.backends.cadence.aot.graph_builder import ( GraphBuilder, single_op_builder, ) -from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match from executorch.backends.cadence.aot.replace_ops import ( ForceChannelLastForConvPass, MakeSliceAndCatDimOutermostPass, @@ -31,6 +31,7 @@ ReplaceLinearWithFullyConnectedOpPass, ReplaceMatmulWithTransposedMatmulPass, ReplaceMMWithAddMMPass, + ReplaceMulTensorWithMulAndFullOpsPass, ReplaceNopTransposeOrPermuteWithViewPass, ReplacePadWithCatPass, ReplacePermuteWithTransposePass, @@ -46,11 +47,11 @@ ReplaceTrivialConvWithLinear, ReplaceWhereWithFullArgsWithWhereScalar, ) + +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass from executorch.exir.passes import dead_code_elimination_pass - -from parameterized.parameterized import parameterized from torch.fx.passes.infra.pass_base import PassResult @@ -58,9 +59,9 @@ class TestReplaceOpsPasses(unittest.TestCase): def assertTargetCountEqual( self, graph_module: torch.fx.GraphModule, - target: Union[Callable[..., Any], str], + target: torch.fx.node.Target, expected_count: int, - ): + ) -> None: """Helper function to check the number of nodes with a given target.""" actual_count = count_node(graph_module, target) self.assertEqual( @@ -72,13 +73,13 @@ def assertTargetCountEqual( def assertTargetCountsEqual( self, graph_module: torch.fx.GraphModule, - targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]], - ): + targets_and_counts: List[Tuple[torch.fx.node.Target, int]], + ) -> None: """Helper function to check the number of nodes of all types for a given target.""" for target, expected_count in targets_and_counts: self.assertTargetCountEqual(graph_module, target, expected_count) - @parameterized.expand( + @expand( [ ( "regular", @@ -95,7 +96,7 @@ def assertTargetCountsEqual( @torch.no_grad() def test_replace_matmul_with_transposed_matmul( self, - _, + _: str, x_shape: Tuple[int], y_shape: Tuple[int], ) -> None: @@ -131,7 +132,7 @@ def test_replace_matmul_with_transposed_matmul( 1, ) - @parameterized.expand( + @expand( [ ("2d", (3, 5), [0, 0]), # shape # padding ("3d", (20, 1, 80), [0, 0, 0]), # shape # padding @@ -140,7 +141,7 @@ def test_replace_matmul_with_transposed_matmul( @torch.no_grad() def test_replace_constant_pad_nd_with_slice( self, _, shape: Tuple[int], padding: Tuple[int] - ): + ) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) matmul = builder.call_operator( @@ -161,7 +162,7 @@ def test_replace_constant_pad_nd_with_slice( 0, ) - @parameterized.expand( + @expand( [ ["3d", (7, 5, 6), 1.23], ["2d", (7, 5), 2], @@ -171,7 +172,7 @@ def test_replace_constant_pad_nd_with_slice( @torch.no_grad() def test_add_replace_scalar_with_tensor_arg( self, _, shape: Tuple[int], other: float - ): + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -189,7 +190,7 @@ def test_add_replace_scalar_with_tensor_arg( 0, ) - @parameterized.expand( + @expand( [ ["3d", (7, 5, 6), 1.23], ["2d", (7, 5), 2], @@ -199,7 +200,7 @@ def test_add_replace_scalar_with_tensor_arg( @torch.no_grad() def test_sub_replace_scalar_with_tensor_arg( self, _, shape: Tuple[int], other: float - ): + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -217,7 +218,7 @@ def test_sub_replace_scalar_with_tensor_arg( 0, ) - @parameterized.expand( + @expand( [ ["3d", (7, 5, 6), 1.23], ["2d", (7, 5), 2], @@ -227,7 +228,7 @@ def test_sub_replace_scalar_with_tensor_arg( @torch.no_grad() def test_mul_replace_scalar_with_tensor_arg( self, _, shape: Tuple[int], other: float - ): + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -245,7 +246,7 @@ def test_mul_replace_scalar_with_tensor_arg( 0, ) - @parameterized.expand( + @expand( [ ["3d", (7, 5, 6), 1.23], ["2d", (7, 5), 2], @@ -258,7 +259,7 @@ def test_div_replace_scalar_with_tensor_arg( _, shape: Tuple[int], other: float, - ): + ) -> None: x = torch.randn(*shape) original_gm = single_op_builder( placeholders=(x,), @@ -276,7 +277,7 @@ def test_div_replace_scalar_with_tensor_arg( 0, ) - @parameterized.expand( + @expand( [ ["4d", (2, 3, 5, 6)], ["3d", (7, 6, 5)], @@ -287,7 +288,7 @@ def test_div_replace_scalar_with_tensor_arg( @torch.no_grad() def test_replace_functionally_equivalent_op_targets_relu( self, _, shape: Tuple[int] - ): + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -306,16 +307,26 @@ def test_replace_functionally_equivalent_op_targets_relu( 0, ) - @parameterized.expand( - [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)] - + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)] - + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)] - + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)] + @expand( + [ + ("split_linear_tensor_split_size_2", (50,), 2, 0), + ("split_linear_tensor_split_size_5", (50,), 5, 0), + ("split_linear_tensor_split_size_7", (50,), 7, 0), + ("split_leading_dim_split_size_2", (10, 2, 3), 2, 0), + ("split_leading_dim_split_size_5", (10, 2, 3), 5, 0), + ("split_leading_dim_split_size_7", (10, 2, 3), 7, 0), + ("split_trailing_dim_split_size_2", (3, 3, 6), 2, 2), + ("split_trailing_dim_split_size_4", (3, 3, 6), 4, 2), + ("split_trailing_dim_split_size_6", (3, 3, 6), 6, 2), + ("split_middle_dim_split_size_2", (3, 5, 14, 2, 3), 2, 2), + ("split_middle_dim_split_size_5", (3, 5, 14, 2, 3), 5, 2), + ("split_middle_dim_split_size_7", (3, 5, 14, 2, 3), 7, 2), + ] ) @torch.no_grad() def test_replace_functionally_equivalent_op_targets_unsafe_split( self, _, shape: Tuple[int], split_size: int, dim: int - ): + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -332,7 +343,7 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split( count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x ) - @parameterized.expand( + @expand( [ [(1, 8, 33), 8, 16, 3], [(1, 8, 33), 8, 16, 5, 2], @@ -355,7 +366,7 @@ def test_replace_transposed_conv_with_linear( depthwise: bool = False, bias_enabled: bool = True, channel_last: bool = False, - ): + ) -> None: transposed = True output_padding = [0] groups = in_channels if depthwise else 1 @@ -417,7 +428,7 @@ def test_replace_transposed_conv_with_linear( 0, ) - @parameterized.expand( + @expand( [ [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False], # # depthwise @@ -441,7 +452,7 @@ def test_replace_convolution_optional_args_with_concrete_args( depthwise: bool = False, bias_enabled: bool = True, channel_last: bool = False, - ): + ) -> None: transposed = True output_padding = [0] groups = in_channels if depthwise else 1 @@ -495,7 +506,7 @@ def test_replace_convolution_optional_args_with_concrete_args( 1, ) - @parameterized.expand( + @expand( [ [(1, 2, 3), [1, 1]], [ @@ -505,7 +516,7 @@ def test_replace_convolution_optional_args_with_concrete_args( ] ) @torch.no_grad() - def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]): + def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -524,7 +535,7 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]): ) @torch.no_grad() - def test_replace_repeat_with_cat(self): + def test_replace_repeat_with_cat(self) -> None: x = torch.randn([3, 5]) original_gm = single_op_builder( placeholders=(x,), @@ -542,7 +553,7 @@ def test_replace_repeat_with_cat(self): 0, ) - @parameterized.expand( + @expand( [ # x, mask [(1,)], @@ -561,7 +572,7 @@ def test_replace_masked_scalar_tensor_with_full( self, shape: Tuple[int], mask_shape: Union[Tuple[int, ...], None] = None, - ): + ) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) mask = builder.placeholder( @@ -601,7 +612,7 @@ def test_replace_masked_scalar_tensor_with_full( @torch.no_grad() def test_replace_scalar_tensor_with_full( self, - ): + ) -> None: original_gm = single_op_builder( placeholders=(), op=exir_ops.edge.aten.scalar_tensor.default, @@ -619,7 +630,7 @@ def test_replace_scalar_tensor_with_full( ) @torch.no_grad() - def test_replace_linear_with_fully_connected(self): + def test_replace_linear_with_fully_connected(self) -> None: shape, in_channels, out_channels = (1, 14), 14, 128 builder = GraphBuilder() x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) @@ -660,7 +671,7 @@ def test_replace_linear_with_fully_connected(self): 0, ) - @parameterized.expand( + @expand( [ [(4, 16, 256), 256, 512, True], [(7, 17, 12), 12, 34, False], @@ -669,7 +680,7 @@ def test_replace_linear_with_fully_connected(self): @torch.no_grad() def test_replace_addmm_with_linear( self, shape: Tuple[int], in_features: int, out_features: int, bias: bool - ): + ) -> None: M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0 builder = GraphBuilder() x = builder.placeholder("x", torch.randn(N, dtype=torch.float32)) @@ -703,7 +714,7 @@ def test_replace_addmm_with_linear( ) @torch.no_grad() - def test_replace_mm_with_addmm(self): + def test_replace_mm_with_addmm(self) -> None: M, K, N = 14, 48, 24 x = torch.randn([M, K]) y = torch.randn([K, N]) @@ -724,7 +735,7 @@ def test_replace_mm_with_addmm(self): 0, ) - @parameterized.expand( + @expand( [ # shape [(5, 1, 6, 7)], @@ -737,7 +748,9 @@ def test_replace_mm_with_addmm(self): ] ) @torch.no_grad() - def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None): + def test_replace_squeeze_with_view( + self, shape: Tuple[int], dim: Optional[int] = None + ) -> None: x = torch.randn(shape) if dim: original_gm = single_op_builder( @@ -769,7 +782,7 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None): 0, ) - @parameterized.expand( + @expand( [ # shape, dim to unsqueeze [(5, 6, 7), 0], @@ -779,7 +792,7 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None): ] ) @torch.no_grad() - def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int): + def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -803,7 +816,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar( self, in_features: int = 16, out_features: int = 16, - ): + ) -> None: src_zero_point = 0 out_zero_point = 0 builder = GraphBuilder() @@ -872,7 +885,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_ self, in_features: int = 16, out_features: int = 16, - ): + ) -> None: src_zero_point = 0 out_zero_point = 0 builder = GraphBuilder() @@ -945,7 +958,7 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_ ) @torch.no_grad() - def test_replace_conv1d_with_linear(self): + def test_replace_conv1d_with_linear(self) -> None: x = torch.randn(1, 96, 7) weights = torch.randn(192, 96, 7) bias = torch.randn(192) @@ -956,11 +969,12 @@ def test_replace_conv1d_with_linear(self): ) # First, replace the aten convolution with a cadence.convolution op p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass() - temp_graph = p1(original_gm).graph_module + temp_graph = cast(PassResult, p1(original_gm)).graph_module + # temp_graph = p1(original_gm).graph_module self.assertIsNotNone(temp_graph) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = p2(temp_graph).graph_module + graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module # Assert that conv1d is trivially converted to linear self.assertEqual( @@ -978,7 +992,7 @@ def test_replace_conv1d_with_linear(self): ) @torch.no_grad() - def test_replace_conv2d_with_linear(self): + def test_replace_conv2d_with_linear(self) -> None: x = torch.randn(1, 96, 7, 7) weights = torch.randn(192, 96, 7, 7) bias = torch.randn(192) @@ -989,11 +1003,11 @@ def test_replace_conv2d_with_linear(self): ) # First, replace the aten convolution with a cadence.convolution op p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass() - temp_graph = p1(original_gm).graph_module + temp_graph = cast(PassResult, p1(original_gm)).graph_module self.assertIsNotNone(temp_graph) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = p2(temp_graph).graph_module + graph_after_passes = cast(PassResult, p2(temp_graph)).graph_module # Assert that conv2d is trivially converted to linear self.assertEqual( @@ -1011,7 +1025,7 @@ def test_replace_conv2d_with_linear(self): ) @torch.no_grad() - def test_replace_conv2d_with_im2row_and_linear(self): + def test_replace_conv2d_with_im2row_and_linear(self) -> None: x = torch.randn(1, 96, 47, 37) weights = torch.randn(192, 96, 7, 7) bias = torch.randn(192) @@ -1034,14 +1048,16 @@ def test_replace_conv2d_with_im2row_and_linear(self): count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1 ) - @parameterized.expand( + @expand( [ [(3, 1, 5), 1, 0], [(3, 4, 1), 2, -1], ] ) @torch.no_grad() - def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int): + def test_replace_select_with_view( + self, shape: Tuple[int], dim: int, index: int + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -1058,7 +1074,7 @@ def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int) count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 ) - @parameterized.expand( + @expand( [ [(2, 1, 3, 1), 1, 3, torch.float32], [(2, 1, 5), 1, 0, torch.int64], @@ -1072,7 +1088,7 @@ def test_replace_nop_transpose_with_view( dim0: int, dim1: int, dtype: torch.dtype = torch.float32, - ): + ) -> None: if dtype == torch.float32: x = torch.randn(shape) else: @@ -1093,7 +1109,7 @@ def test_replace_nop_transpose_with_view( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 ) - @parameterized.expand( + @expand( [ # permutations that can be replaced by view [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3)], @@ -1101,7 +1117,9 @@ def test_replace_nop_transpose_with_view( ] ) @torch.no_grad() - def test_replace_nop_permute_with_view(self, shape, dims): + def test_replace_nop_permute_with_view( + self, shape: Tuple[int], dims: Tuple[int] + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -1119,15 +1137,17 @@ def test_replace_nop_permute_with_view(self, shape, dims): count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 ) - @parameterized.expand( + @expand( [ # permutations replaced by transpose - [(3, 4), [1, 0]], + [(3, 4), (1, 0)], [(3, 4, 6), (0, 2, 1)], ] ) @torch.no_grad() - def test_replace_permute_with_transpose(self, shape, dims): + def test_replace_permute_with_transpose( + self, shape: Tuple[int], dims: Tuple[int] + ) -> None: x = torch.randn(shape) original_gm = single_op_builder( placeholders=(x,), @@ -1148,7 +1168,7 @@ def test_replace_permute_with_transpose(self, shape, dims): @torch.no_grad() def test_replace_permute_with_transpose_nop( self, - ): + ) -> None: x = torch.randn(3, 4) original_gm = single_op_builder( placeholders=(x,), @@ -1166,7 +1186,7 @@ def test_replace_permute_with_transpose_nop( count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0 ) - def test_replace_aten_where_with_cadence(self): + def test_replace_aten_where_with_cadence(self) -> None: builder = GraphBuilder() cond = builder.placeholder("cond", torch.randn(4, 8)) aten_gt_scalar = builder.call_operator( @@ -1201,7 +1221,7 @@ def test_replace_aten_where_with_cadence(self): 1, ) - @parameterized.expand( + @expand( [ [(4, 8), (4, 8), (4, 8), 0.0, 1.0], [(8,), (4, 8), (8,), 0.0, 1.0], @@ -1209,8 +1229,13 @@ def test_replace_aten_where_with_cadence(self): ] ) def test_replace_aten_where_with_cadence_broadcast( - self, cond_shape, a_shape, b_shape, val1, val2 - ): + self, + cond_shape: Tuple[int], + a_shape: Tuple[int], + b_shape: Tuple[int], + val1: float, + val2: float, + ) -> None: # cond_shape, a_shape, b_shape, val1, val2 = builder = GraphBuilder() cond = builder.placeholder("cond", torch.randn(cond_shape)) @@ -1242,7 +1267,7 @@ def test_replace_aten_where_with_cadence_broadcast( 1, ) - def test_no_replace_aten_gelu_with_approximate_gelu(self): + def test_no_replace_aten_gelu_with_approximate_gelu(self) -> None: inputs = torch.randn(2, 1, 64) gm = single_op_builder( @@ -1264,7 +1289,7 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self): 1, ) - def test_replace_split_with_sizes_with_slice(self): + def test_replace_split_with_sizes_with_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) split = builder.call_operator( @@ -1290,8 +1315,8 @@ def test_replace_split_with_sizes_with_slice(self): 2, ) - @parameterized.expand([[2], [3], [4]]) - def test_replace_pow_with_mul(self, exponent: int): + @expand([[2], [3], [4]]) + def test_replace_pow_with_mul(self, exponent: int) -> None: x = torch.randn(2, 1, 64) original_gm = single_op_builder( placeholders=(x,), @@ -1315,13 +1340,13 @@ def test_replace_pow_with_mul(self, exponent: int): exponent - 1, ) - @parameterized.expand( + @expand( [ [1], [1.5], ] ) - def test_replace_pow_with_mul_not_applied(self, exponent): + def test_replace_pow_with_mul_not_applied(self, exponent: float) -> None: x = torch.randn(2, 1, 64) original_gm = single_op_builder( placeholders=(x,), @@ -1349,7 +1374,7 @@ def test_replace_pow_with_mul_not_applied(self, exponent): class TestReplaceIm2rowWithViewPass(unittest.TestCase): - def test_no_replacement_for_conv(self): + def test_no_replacement_for_conv(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 224, 224) pad_value = torch.randn(1) @@ -1375,7 +1400,7 @@ def test_no_replacement_for_conv(self): count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0 ) - def test_no_replace_for_dilation(self): + def test_no_replace_for_dilation(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 5, 7) pad_value = torch.randn(1) @@ -1400,7 +1425,7 @@ def test_no_replace_for_dilation(self): count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0 ) - def test_replace_linear_like_conv(self): + def test_replace_linear_like_conv(self) -> None: # Create a graph with a single im2row node. in_h, in_w = 13, 15 x = torch.randn(1, 3, in_h, in_w) @@ -1454,7 +1479,7 @@ def create_conv1d_graphmodule( args=args, ) - def test_conv1d_default_channel_last(self): + def test_conv1d_default_channel_last(self) -> None: # Create a graph with a single convolution node. # Check if graph module is valid by running exportpass on it. gm = self.create_conv1d_graphmodule() @@ -1482,7 +1507,7 @@ def test_conv1d_default_channel_last(self): self.assertEqual(len(node.args), 8, f"{node=}") self.assertTrue(node.args[7]) - def test_conv1d_no_transpose_if_already_channel_last(self): + def test_conv1d_no_transpose_if_already_channel_last(self) -> None: gm = self.create_conv1d_graphmodule(channels_last=True) gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) @@ -1531,7 +1556,7 @@ def create_convolution_graph_module( args=args, ) - def test_convolution_default_channel_last(self): + def test_convolution_default_channel_last(self) -> None: # Create a graph with a single convolution node. # Check if graph module is valid by running exportpass on it. gm = self.create_convolution_graph_module() @@ -1559,7 +1584,7 @@ def test_convolution_default_channel_last(self): self.assertEqual(len(node.args), 8, f"{node=}") self.assertTrue(node.args[7]) - def test_no_transpose_if_already_channel_last(self): + def test_no_transpose_if_already_channel_last(self) -> None: gm = self.create_convolution_graph_module(channels_last=True) gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1) @@ -1636,7 +1661,7 @@ def create_quantized_convolution_graph_module( args=args, ) - def test_quantized_convolution_default_channel_last(self): + def test_quantized_convolution_default_channel_last(self) -> None: # Create a graph with a single convolution node. gm = self.create_quantized_convolution_graph_module() self.assertEqual( @@ -1666,7 +1691,7 @@ def test_quantized_convolution_default_channel_last(self): self.assertEqual(len(node.args), 15, f"{node=}") self.assertTrue(node.args[14]) - def test_no_transpose_if_already_quantized_conv_channel_last(self): + def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Create a graph with a single im2row node. gm = self.create_quantized_convolution_graph_module(channels_last=True) # Check if graph module is valid by running exportpass on it. @@ -1709,7 +1734,7 @@ def create_slice_graph( args=(x, slice_dim, slice_begin, slice_end), ) - def test_slice_no_transpose_if_already_outermost(self): + def test_slice_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node. gm = self.create_slice_graph((3, 224, 224), 0, 1, 2) # Check if graph module is valid by running exportpass on it. @@ -1717,7 +1742,8 @@ def test_slice_no_transpose_if_already_outermost(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that no transpose ops were added. self.assertEqual( @@ -1725,7 +1751,7 @@ def test_slice_no_transpose_if_already_outermost(self): 0, ) - def test_slice_no_transpose_if_outermost_dimensions_are_one(self): + def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2) # Check if graph module is valid by running exportpass on it. @@ -1733,7 +1759,8 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1742,7 +1769,7 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self): 0, ) - def test_slice_insert_transpose(self): + def test_slice_insert_transpose(self) -> None: # Create a graph with a single slice node. gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2) # Check if graph module is valid by running exportpass on it. @@ -1750,7 +1777,8 @@ def test_slice_insert_transpose(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that there are two transpose ops added. self.assertEqual( @@ -1770,7 +1798,7 @@ def create_cat_graph( args=(input_tensors, cat_dim), ) - def test_cat_no_transpose_if_already_outermost(self): + def test_cat_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node on second outermost dimension. gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0) # Check if graph module is valid by running exportpass on it. @@ -1778,7 +1806,8 @@ def test_cat_no_transpose_if_already_outermost(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1787,7 +1816,7 @@ def test_cat_no_transpose_if_already_outermost(self): 0, ) - def test_cat_no_transpose_if_outermost_dimensions_are_one(self): + def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1) # Check if graph module is valid by running exportpass on it. @@ -1795,7 +1824,8 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1804,7 +1834,7 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self): 0, ) - def test_cat_insert_transpose(self): + def test_cat_insert_transpose(self) -> None: # Create a graph with a single slice node on second outermost dimension. gm = self.create_cat_graph( input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1 @@ -1814,7 +1844,8 @@ def test_cat_insert_transpose(self): self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) # Apply replacement pass. - gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module + p = MakeSliceAndCatDimOutermostPass() + gm_after_pass = cast(PassResult, p(gm)).graph_module # Assert that transpose ops were added to make cat on outermost dimension. self.assertEqual( @@ -1840,7 +1871,7 @@ def _get_slice_empty_gm(self) -> torch.fx.GraphModule: builder.output([cat]) return builder.get_graph_module() - def test_empty_slice(self): + def test_empty_slice(self) -> None: gm = self._get_slice_empty_gm() self.assertEqual( len( @@ -1858,7 +1889,8 @@ def test_empty_slice(self): ), 0, ) - updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module + p = ReplaceEmptyTensorsWithFullPass() + updated_gm = cast(PassResult, p(gm)).graph_module self.assertEqual( len( updated_gm.graph.find_nodes( @@ -1875,3 +1907,32 @@ def test_empty_slice(self): ), 1, ) + + @expand( + [ + ("int", int(123)), + ("float", float(456.0)), + ], + ) + @torch.no_grad() + def test_extract_mul_argument_to_full( + self, _: str, value: Union[int, float] + ) -> None: + x = torch.randn(2, 1, 64) + gm = single_op_builder( + placeholders=(x,), + op=torch.ops.aten.mul.Tensor, + args=(x, value), + kwargs={}, + ) + p = ReplaceMulTensorWithMulAndFullOpsPass() + graph_after_passes = p.call(gm).graph_module + self.assertTrue( + op_counts_match( + graph_after_passes, + expected_op_counts={ + torch.ops.aten.mul.Tensor: 1, + exir_ops.edge.aten.full.default: 1, + }, + ) + ) diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py index 195c0ff00ab..f26fe897e1e 100644 --- a/backends/cadence/aot/tests/test_simplify_ops_passes.py +++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import unittest @@ -18,13 +18,13 @@ BindOptionalArgsPass, SimplifySliceOpPass, ) +from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops -from parameterized.parameterized import parameterized from torch.fx.passes.infra.pass_base import PassResult class TestSimplifyOpsPasses(unittest.TestCase): - @parameterized.expand( + @expand( [ [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3], ] @@ -38,7 +38,7 @@ def test_simplify_slice_scatter_op( start: Optional[int] = None, end: Optional[int] = None, step: int = 1, - ): + ) -> None: x = torch.randn(*in_shape) y = torch.randn(*src_shape) gm = single_op_builder( @@ -50,7 +50,7 @@ def test_simplify_slice_scatter_op( gm = cast(PassResult, p(gm)).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0) - @parameterized.expand( + @expand( [ [(3, 16, 5), 1, 15, 3, 3], ] @@ -63,7 +63,7 @@ def test_simplify_slice_op( start: Optional[int] = None, end: Optional[int] = None, step: int = 1, - ): + ) -> None: x = torch.randn(*in_shape) gm = single_op_builder( placeholders=(x,), diff --git a/backends/cadence/aot/typing_stubs.py b/backends/cadence/aot/typing_stubs.py new file mode 100644 index 00000000000..f15628f7948 --- /dev/null +++ b/backends/cadence/aot/typing_stubs.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable + + # This only runs during static type checking (not at runtime) + def expand(arg: object) -> Callable[..., None]: ... + +else: + # Real import used at runtime + # from parameterized.parameterized import parameterized.expand as expand # noqa + from parameterized.parameterized import parameterized + + expand = parameterized.expand diff --git a/backends/mediatek/README.md b/backends/mediatek/README.md index 665e11ce266..e8a535b3fde 100644 --- a/backends/mediatek/README.md +++ b/backends/mediatek/README.md @@ -14,23 +14,11 @@ The examples provided in this repository are tested and supported on the followi Before you begin, ensure you have the following prerequisites installed and configured: -#### 1. Buck2 Build Tool - -- **Download Buck2**: Obtain Buck2 from the official [releases page](https://github.com/facebook/buck2/releases/tag/2024-02-01). -- **Add to PATH**: Extract the downloaded file and add the directory to your system's `$PATH` environment variable. - ```bash - export PATH=:$PATH - ``` - -#### 2. Android NDK +#### 1. Android NDK - **Download Android NDK**: Acquire the Android NDK version 26.3.11579264 from the [Android developer site](https://developer.android.com/ndk/downloads). -- **Set NDK Path**: Ensure that the `$ANDROID_NDK` environment variable is set to the path where the NDK is located. - ```bash - export ANDROID_NDK= - ``` -#### 3. MediaTek ExecuTorch Libraries +#### 2. MediaTek ExecuTorch Libraries To get started with MediaTek's ExecuTorch libraries, download the [NeuroPilot Express SDK](https://neuropilot.mediatek.com/resources/public/npexpress/en/docs/npexpress) from MediaTek's NeuroPilot portal. The SDK includes the following components: @@ -60,26 +48,28 @@ Follow the steps below to setup your build environment: pip3 install mtk_neuron-8.2.19-py3-none-linux_x86_64.whl pip3 install mtk_converter-8.13.0+public-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ``` -- Set evironment variables for building backend - ```bash - export NEURON_BUFFER_ALLOCATOR_LIB= - ``` ### Build -1. Navigate to `scripts/` directory. - -2. **Build MediaTek Backend**: Once the prerequisites are in place, run the `mtk_build.sh` script to start the build process, MediaTek backend will be built under `cmake-android-out/backends/` as `libneuron_backend.so` +1. Copy `NeuronAdapter.h` to `backends/mediatek/runtime/include/api/` +2. Set NDK Path: Ensure that the `$ANDROID_NDK` environment variable is set to the path where the NDK is located. ```bash - ./mtk_build.sh + export ANDROID_NDK= ``` -### Run +3. Build the backend library `libneuron_backend.so`: + ```bash + cd backends/mediatek/scripts/ + ./mtk_build.sh + ``` +The output is `libneuron_backend.so` in `cmake-android-out/backends/mediatek/`. -1. **Push MediaTek universal SDK and MediaTek backend to the device**: push `libneuronusdk_adapter.mtk.so` and `libneuron_backend.so` to the phone and export it to the `$LD_LIBRARY_PATH` environment variable before executing ExecuTorch with MediaTek backend. +### Run +1. Push `libneuron_backend.so`, `libneuronusdk_adapter.mtk.so` and `libneuron_buffer_allocator.so` to the device. +2. Set the library path before running ExecuTorch: ```bash - export LD_LIBRARY_PATH=::$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=:::$LD_LIBRARY_PATH ``` Please refer to `executorch/examples/mediatek/` for export and execution examples of various of models. diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index ca1aa78ef17..01710aa8d80 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -19,6 +19,7 @@ from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu +from .decompose_wrap_with_autocast import DecomposeWrapWithAutocast from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape from .fixed_linear_keep_dim import FixedLinearKeepDim from .fold_qdq import FoldQDQ @@ -35,7 +36,6 @@ from .remove_0d_tensor import Remove0DTensor from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs -from .replace_index_put_input import ReplaceIndexPutInput from .replace_inf_values import ReplaceInfValues from .tag_quant_io import TagQuantIO @@ -56,6 +56,7 @@ DecomposeLinalgVectorNorm, DecomposeRoll, DecomposeSilu, + DecomposeWrapWithAutocast, ExpandBroadcastTensorShape, FixedLinearKeepDim, FoldQDQ, @@ -72,7 +73,6 @@ Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, - ReplaceIndexPutInput, ReplaceInfValues, TagQuantIO, ] diff --git a/backends/qualcomm/_passes/decompose_wrap_with_autocast.py b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py new file mode 100644 index 00000000000..6c073bd309c --- /dev/null +++ b/backends/qualcomm/_passes/decompose_wrap_with_autocast.py @@ -0,0 +1,88 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import _operator +from typing import Dict, Tuple + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_nn_module_stack + + +class DecomposeWrapWithAutocast(ExportPass): + """ + Decompose the _higher_order_ops WrapWithAutocast + """ + + def __init__(self) -> None: + super().__init__() + + def _get_submod( + self, gm: torch.fx.GraphModule, node: torch.fx.Node + ) -> Tuple[torch.fx.GraphModule, str]: + for a in node.args: + if isinstance(a, torch.fx.Node) and "submod" in a.target: + return getattr(gm, a.target), a.target + + def _replace_output( + self, wwac_node: torch.fx.Node, output_node: torch.fx.Node, remap: Dict + ): + for user in wwac_node.users.copy(): + arg_idx = 0 + is_user_getitem = False + + if user.target == _operator.getitem: + arg_idx = user.args[1] + is_user_getitem = True + + user.replace_input_with( + wwac_node, + remap[output_node.args[0][arg_idx]], + ) + + if is_user_getitem: + for user_user in user.users.copy(): + user_user.replace_input_with(user, user.args[0]) + + def _replace(self, gm: torch.fx.GraphModule) -> None: + graph = gm.graph + for node in graph.nodes: + if isinstance(node.target, torch._higher_order_ops.wrap.WrapWithAutocast): + submod, submod_name = self._get_submod(gm, node) + n_args = node.args + input_submod = n_args[4] + decomposed_module = submod + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + # remap = {"expand_1": node.args[5], "to_4": node.args[6]} + remap = {n_args[i].name: n_args[i] for i in range(5, len(n_args))} + + for decomposed_node in decomposed_module.graph.nodes: + copy_nn_module_stack(node, decomposed_node) + # no need to copy existent 'output' + if decomposed_node.op == "output": + self._replace_output(node, decomposed_node, remap) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + graph.erase_node(node) + + graph.erase_node(input_submod) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self._replace(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/insert_io_qdq.py b/backends/qualcomm/_passes/insert_io_qdq.py index e5b15f2d12c..caecae64fa8 100644 --- a/backends/qualcomm/_passes/insert_io_qdq.py +++ b/backends/qualcomm/_passes/insert_io_qdq.py @@ -9,7 +9,10 @@ from executorch.backends.qualcomm.builders.node_visitor import q_ops -from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.builders.utils import ( + is_mutable_buffer_input, + is_parameter, +) from executorch.backends.qualcomm.utils.constants import ( QCOM_ENCODING, QCOM_QUANT_ATTRS, @@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if ( n.op == "placeholder" and n.meta.get(QCOM_QUANT_ATTRS) - and not is_parameter(n, self.edge_program) + and ( + not is_parameter(n, self.edge_program) + or is_mutable_buffer_input(n, self.edge_program) + ) ): self._insert_quant_node( graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index bb6a4dd0a67..8340fa6209e 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -24,6 +24,7 @@ DecomposeLinalgVectorNorm, DecomposeRoll, DecomposeSilu, + DecomposeWrapWithAutocast, ExpandBroadcastTensorShape, FixedLinearKeepDim, FoldQDQ, @@ -40,7 +41,6 @@ Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, - ReplaceIndexPutInput, ReplaceInfValues, TagQuantIO, ) @@ -80,7 +80,7 @@ def get_capture_program_passes(): (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), - (ConvertBmmToMatmul, True), + (ConvertBmmToMatmul, False), (ConvertConv1dToConv2d, True), (DecomposeAny, True), (DecomposeColIm, True), @@ -93,7 +93,6 @@ def get_capture_program_passes(): (RecomposeRmsNorm, False), (Remove0DTensor, True), (RemoveRedundancy, True), - (ReplaceIndexPutInput, True), (TagQuantIO, False), ] @@ -194,6 +193,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeWrapWithAutocast()) self.add_pass(DecomposeEinsum()) self.add_pass(DecomposeExpM1()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) @@ -207,6 +207,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram): self.add_pass(DecomposeRoll()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) + self.add_pass(DecomposeWrapWithAutocast()) # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower self.add_pass(ConvertConv1dToConv2d(exported_program)) @@ -223,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): self.add_pass(LayoutTransform(exported_program, insert_permute=True)) self.add_pass(FuseConsecutiveCast()) self.add_pass(FuseConsecutiveTranspose()) - return self._transform(exported_program.graph_module) + self._transform(exported_program.graph_module) + # Update inputs_to_buffers and buffers_to_mutate in graph signature for mutable buffer + # Since I/O will be inserted Q/DQ, it results in failed to mapping output node names and buffer + exported_program._graph_signature = _get_updated_graph_signature( + exported_program.graph_signature, + exported_program.graph_module, + ) + return exported_program.graph_module diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py index bff917be3da..22d476ef21b 100644 --- a/backends/qualcomm/_passes/remove_redundancy.py +++ b/backends/qualcomm/_passes/remove_redundancy.py @@ -43,6 +43,8 @@ def _dim_order_op_condition(self, node): dim_order = node.kwargs.get("dim_order") # skip if there contains layout hint # e.g. (0, 2, 3, 1) != (0, 1, 2, 3) + if node.meta["val"].dtype != node.args[0].meta["val"].dtype: + return False return dim_order != list(range(len(dim_order))) def _to_copy_op_condition(self, node): @@ -53,19 +55,15 @@ def _default_condition(self, ndoe): def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: - if n.target not in self.redundant_ops or not self.redundant_ops[n.target]( - n - ): - continue - - to_be_remove = n - # assert_tensor_metadata op has no user - if len(n.users.keys()) == 0: - n.args = () - # normal case - for user_n in list(n.users.keys()): - user_n.replace_input_with(n, n.args[0]) - graph_module.graph.erase_node(to_be_remove) + if n.target in self.redundant_ops and self.redundant_ops[n.target](n): + to_be_remove = n + # assert_tensor_metadata op has no user + if len(n.users.keys()) == 0: + n.args = () + # normal case + for user_n in list(n.users.keys()): + user_n.replace_input_with(n, n.args[0]) + graph_module.graph.erase_node(to_be_remove) def call(self, graph_module: torch.fx.GraphModule): self._remove(graph_module) diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py deleted file mode 100644 index 93ee21bfc7c..00000000000 --- a/backends/qualcomm/_passes/replace_index_put_input.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch -from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class ReplaceIndexPutInput(ExportPass): - """ - Index put input workaround for quantized module - """ - - dq_q_map = { - # per tensor - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - # per channel - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - } - - def __init__(self): - super(ReplaceIndexPutInput, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - for node in graph.nodes: - if node.target == exir_ops.edge.aten.index_put.default: - if ( - copy_node := list(node.users)[0] - ) and copy_node.target == exir_ops.edge.aten.copy.default: - m_buffer_node = copy_node.args[0] - dq_node = node.args[0] - bad_frozen_node = dq_node.args[0] - if QCOM_QUANT_ATTRS in bad_frozen_node.meta: - m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[ - QCOM_QUANT_ATTRS - ] - m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = ( - self.dq_q_map[ - m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] - ] - ) - with graph.inserting_after(dq_node): - node.replace_input_with(dq_node, m_buffer_node) - else: - continue - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/replace_inf_values.py b/backends/qualcomm/_passes/replace_inf_values.py index c7e475f54f2..bffcea03a72 100644 --- a/backends/qualcomm/_passes/replace_inf_values.py +++ b/backends/qualcomm/_passes/replace_inf_values.py @@ -9,13 +9,13 @@ class ReplaceInfValues(ExportPass): """ - Due to limitation in Qnn, we need to change inf or -inf to arbitrary value in quantization. + Due to limitation in QNN, change inf or -inf to arbitrary value in quantization. """ def __init__(self): super(ReplaceInfValues, self).__init__() - def call(self, graph_module: torch.fx.GraphModule): + def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 for buf_name, tensor in graph_module.named_buffers(): if tensor.is_floating_point(): # 255 here is mainly for attention_mask in Llama for reasonable quant scale @@ -38,5 +38,23 @@ def call(self, graph_module: torch.fx.GraphModule): arg_list[2] = -255 node.args = tuple(arg_list) + if node.target in [ + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar, + ]: + assert ( + len(node.args) == 3 + ), f"Expecting {node.name} to have 3 arguments." + val = node.args[2] + if node.args[2] > torch.finfo(torch.float16).max: + val = 255 + elif node.args[2] < torch.finfo(torch.float16).min: + val = -255 + node.args = ( + node.args[0], + node.args[1], + val, + ) + graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 898e2d5b1f6..ae11ba7b325 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -77,15 +77,14 @@ def get_passes_dependency_for_capture_program(): RecomposePixelUnshuffle, RecomposeRmsNorm, RemoveRedundancy, - ReplaceIndexPutInput, TagQuantIO, ) return { AnnotateAdaptiveAvgPool1D: [RemoveRedundancy], AnnotateQuantAttrs: [ - RecomposePixelUnshuffle, ConvertBmmToMatmul, + RecomposePixelUnshuffle, RemoveRedundancy, ], AnnotateStack: [RemoveRedundancy], @@ -106,8 +105,7 @@ def get_passes_dependency_for_capture_program(): ], RecomposePixelUnshuffle: [RemoveRedundancy], RecomposeRmsNorm: [RemoveRedundancy], - ReplaceIndexPutInput: [LayoutTransform], - TagQuantIO: [ReplaceIndexPutInput], + TagQuantIO: [LayoutTransform], } diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 37fe3615268..8d77a5f47aa 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -41,6 +41,8 @@ get_parameter, is_graph_input, is_graph_output, + is_mutable_buffer_input, + is_mutable_buffer_output, is_parameter, ) @@ -307,7 +309,9 @@ def get_tensor_type( node: torch.fx.Node, tensor_type: PyQnnWrapper.Qnn_TensorType_t, ) -> PyQnnWrapper.Qnn_TensorType_t: - is_input = is_graph_input(node, self.edge_program) + is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input( + node, self.edge_program + ) is_output = is_graph_output(node) # handle logic for input/output tensors if is_input or is_output: @@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims): return dynamic_dims if any(dynamic_dims) else [], nominal_dims + def get_tensor_name( + self, + node: torch.fx.Node, + wrapper_idx: int = 0, + ): + tensor_name = f"{node.name}_{wrapper_idx}" + # The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess, + # the input order between QNN and the original graph’s forward function may differ. + # The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime. + # The `output_` is identified as the graph’s output at runtime to prevent confusion with per_tensor_dump. + if is_mutable_buffer_input(node, self.edge_program): + fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target] + position_index = list( + self.edge_program.graph_signature.buffers_to_mutate.values() + ).index(fqn) + tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}" + elif is_graph_input(node, self.edge_program): + tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}" + elif is_mutable_buffer_output(node, self.edge_program): + position_index = list( + self.edge_program.graph_signature.buffers_to_mutate.keys() + ).index(node.name) + tensor_name = f"output_mutbuf_{position_index}_{tensor_name}" + elif is_graph_output(node): + tensor_name = f"output_{tensor_name}" + return tensor_name + def define_custom_tensor_wrapper( self, node_name: str, @@ -413,16 +444,7 @@ def define_tensor( if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = f"{tensor_source_node.name}_{wrapper_idx}" - if is_graph_input(tensor_source_node, self.edge_program): - tensor_name = ( - "input_" - + str(self.external_ids[tensor_source_node]) - + "_" - + tensor_name - ) - if is_graph_output(tensor_source_node): - tensor_name = "output_" + tensor_name + tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx) dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size() dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims) tensor_type = self.get_tensor_type(tensor_source_node, tensor_type) diff --git a/backends/qualcomm/builders/node_visitor_manager.py b/backends/qualcomm/builders/node_visitor_manager.py index fa9d51db1ad..8c1733fcec3 100644 --- a/backends/qualcomm/builders/node_visitor_manager.py +++ b/backends/qualcomm/builders/node_visitor_manager.py @@ -13,7 +13,7 @@ from .node_visitor import NodeVisitor from .op_custom_op import CustomOp -from .utils import is_graph_input, is_graph_output +from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input # This will hold mapping of all node names to the visitor class @@ -39,7 +39,9 @@ def generate_node_to_external_map( # The order in which we visit the placeholder node is same as the *args # order for the forward(*args) signature for this gm. Using the order of # the nodes as external_id to extract the right arg from *args at runtime - if is_graph_input(node, edge_program): + if is_graph_input(node, edge_program) or is_mutable_buffer_input( + node, edge_program + ): node_to_external_map[node] = len(node_to_external_map) for node in edge_program.graph_module.graph.nodes: if is_graph_output(node): diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index a58075bf06c..de59b1a0489 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -1,9 +1,10 @@ from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper - import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS + from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -22,6 +23,10 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) + # Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here. + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index c82ebaf1bb3..3345f2e1fc9 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -75,6 +75,20 @@ def is_graph_input( return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) +def is_mutable_buffer_input( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer input + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer input + """ + if tensor.op == "placeholder" and is_buffer(edge_program, tensor): + fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target] + # if the buffer is mutated then record that + return fqn in edge_program.graph_signature.buffers_to_mutate.values() + + def is_graph_output(node: torch.fx.Node) -> bool: """ Check if the given tensor is used as a graph output @@ -83,7 +97,7 @@ def is_graph_output(node: torch.fx.Node) -> bool: tensor: EdgeIR Tensor that is being checked for graph input """ for user in node.users.keys(): - # getitem node is skiped, check the op_skip_ops.py + # getitem node is skipped, check the op_skip_ops.py if user.op == "output" or ( user.target.__name__ == "getitem" and is_graph_output(user) ): @@ -91,6 +105,25 @@ def is_graph_output(node: torch.fx.Node) -> bool: return False +def is_mutable_buffer_output( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer output + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer output + """ + return ( + any( + user.op == "output" + or user.target.__name__ == "getitem" + and is_graph_output(user) + for user in tensor.users.keys() + ) + and tensor.name in edge_program.graph_signature.buffers_to_mutate.keys() + ) + + def is_constant( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 776923a1493..9a8ce92e739 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import copy +import logging from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple @@ -29,7 +30,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -42,6 +43,9 @@ ) from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + class QnnOperatorSupport(OperatorSupportBase): def __init__( @@ -124,6 +128,7 @@ def __init__( compiler_specs: List[CompileSpec], skip_node_id_set: set = None, skip_node_op_set: set = None, + skip_mutable_buffer: bool = False, ): self.compiler_specs_snapshot = copy.deepcopy(compiler_specs) @@ -133,6 +138,7 @@ def __init__( self.partition_tags: Dict[str, DelegationSpec] = {} self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set + self.skip_mutable_buffer = skip_mutable_buffer def generate_partitions( self, edge_program: torch.export.ExportedProgram @@ -178,6 +184,15 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu if len(partitions) != 0: self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) + if not self.skip_mutable_buffer: + logger.info( + "Qnn partitioner will delegate torch mutable buffer with the same I/O address during the runtime, " + "so if your model contains mutable buffer, " + "then you can get the better performance with skip_mutable_buffer=False. " + "If you encounter accuracy issue during the runtime, " + "then please set `skip_mutable_buffer=True` and try again." + ) + tag_mutated_buffer(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"): # pop certain keys in meta for not affecting the passes in compilation diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 02318debfa6..4eee818efe5 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -50,6 +50,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.instance_norm.default, torch.ops.aten.leaky_relu.default, torch.ops.aten.linear.default, + torch.ops.aten.matmul.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.pixel_unshuffle.default, torch.ops.aten.prelu.default, diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index ecce4ee3ef0..e1e2ca6dff6 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -462,7 +462,7 @@ def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> No annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.mean.default]) +@register_annotator([torch.ops.aten.mean.default, torch.ops.aten.mean.dim]) def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -604,11 +604,6 @@ def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.mean.dim]) -def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) - - @register_annotator([torch.ops.aten.slice.Tensor]) def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -822,16 +817,32 @@ def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None: [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] ) def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: - input = node.args[0] + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. value = node.args[2] input_qspec_map = {} - input_qspec_map[input] = quantization_config.input_activation - input_qspec_map[value] = SharedQuantizationSpec((input, node)) + input_qspec_map[value] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((value, node)), + _annotated=True, + ) + + +@register_annotator( + [torch.ops.aten.index_copy.default, torch.ops.aten.index_copy_.default] +) +def annotate_index_copy(node: Node, quantization_config: QuantizationConfig) -> None: + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. + value = node.args[3] + + input_qspec_map = {} + input_qspec_map[value] = quantization_config.input_activation node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input, node)), + output_qspec=SharedQuantizationSpec((value, node)), _annotated=True, ) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 0e06015ed91..0024b52dbe9 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -26,6 +26,35 @@ ) +def annotate_eurobert(gm: torch.fx.GraphModule): + """ + QNN does not support int32 -> signed 16bit quant + We need to first annotate this to_fp node as 8bit quant, so it will perform requantize + Final graph should look like: int32 -> convert -> cast -> matmul.args[1] + + """ + quantization_config_8a8w = get_8a8w_qnn_ptq_config() + for node in gm.graph.nodes: + # A little tricky here. This matmul node is wrapped inside a submodule after 1st torch.export. + # There are actually 2 'to' op that is redundant. + # It will look like: int64 -> to_fp -> to_fp -> matmul.args[1] + # Draw out the graph after the 1st export will help visualize the submodule. + + if node.target == torch.ops.aten.matmul.default and node.args[1].args[0].args[ + 0 + ].meta["val"].dtype in [torch.int64, torch.int32]: + to_node = node.args[1] + input_qspec_map = {} + assert isinstance(to_node, Node) + input_spec = quantization_config_8a8w.input_activation + input_qspec_map[to_node] = input_spec + to_node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config_8a8w.output_activation, + _annotated=True, + ) + + def annotate_mimi_decoder(gm: torch.fx.GraphModule): """ The 1st transpose conv in mimi decoder is really sensitive to scale/offset in 16a8w, which causes execution failure. @@ -204,7 +233,7 @@ def annotate_matmul_input1(node: Node): ) quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config( act_dtype=torch.uint8, - weight_dtype="int4", + weight_dtype=torch.int4, act_observer=MinMaxObserver, act_symmetric=True, ) @@ -263,14 +292,15 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): ) def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: - input = node.args[0] + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. value = node.args[2] + input_qspec_map = {} - input_qspec_map[input] = quantization_config.input_activation - input_qspec_map[value] = SharedQuantizationSpec((input, node)) + input_qspec_map[value] = quantization_config.input_activation + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input, node)), + output_qspec=SharedQuantizationSpec((value, node)), _annotated=True, ) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index e2a9cd83567..748128ceafd 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -241,8 +241,7 @@ def get_ptq_per_channel_quant_config( torch.int8, torch.int16, } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} + supported_weight_dtypes = {torch.int4, torch.int8, torch.int16} assert ( act_dtype in supported_act_types ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" @@ -276,9 +275,11 @@ def get_ptq_per_channel_quant_config( ) weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), @@ -310,9 +311,11 @@ def get_ptq_per_block_quant_config( act_symmetric=act_symmetric, ) weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args), @@ -463,8 +466,7 @@ def get_qat_per_channel_quant_config( torch.int8, torch.int16, } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} + supported_weight_dtypes = {torch.int4, torch.int8, torch.int16} assert ( act_dtype in supported_act_types ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" @@ -491,17 +493,21 @@ def get_qat_per_channel_quant_config( ) weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, observer=MovingAveragePerChannelMinMaxObserver, ) weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype, + quant_min=( + -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1 + ), + quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, observer_or_fake_quant_ctr=weight_fake_quant_ctr, diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 9a149e7db87..7298e02aa0c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -85,7 +85,7 @@ class QuantDtype(IntEnum): partial( get_ptq_per_channel_quant_config, act_dtype=torch.uint16, - weight_dtype="int4", + weight_dtype=torch.int4, ), None, ), @@ -94,12 +94,12 @@ class QuantDtype(IntEnum): partial( get_ptq_per_channel_quant_config, act_dtype=torch.uint16, - weight_dtype="int4", + weight_dtype=torch.int4, ), partial( get_ptq_per_block_quant_config, act_dtype=torch.uint16, - weight_dtype="int4", + weight_dtype=torch.int4, ), ), (QuantDtype.use_8a8w, False): ( @@ -113,7 +113,7 @@ class QuantDtype(IntEnum): partial( get_qat_per_channel_quant_config, act_dtype=torch.uint16, - weight_dtype="int4", + weight_dtype=torch.int4, ), None, ), diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index ab038404582..01bf13603d6 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -129,33 +129,37 @@ Error QnnExecuTorchBackend::execute( std::vector input_tensor_structs; std::vector output_tensor_structs; + int args_index = 0; input_tensor_structs.reserve(input_tensors.size()); - for (int i = 0; i < input_tensors.size(); ++i) { - if (qnn_manager->RegisterMem( - args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) != - Error::Ok) { - // update data ptr only should be fine - input_tensors[i]->FillDataBuffer( - args[i]->toTensor().const_data_ptr(), false /* copy_data */); + for (const auto& input_tensor : input_tensors) { + if (input_tensor->GetName().find("mutbuf_") == std::string::npos) { + if (qnn_manager->RegisterMem( + args[args_index]->toTensor().mutable_data_ptr(), input_tensor) != + Error::Ok) { + // update data ptr only should be fine + input_tensor->FillDataBuffer( + args[args_index]->toTensor().const_data_ptr(), + false /* copy_data */); + // use the real input shape instead of nominal one to make sure + // dynamic shape is functional + auto dims = args[args_index]->toTensor().sizes(); + input_tensor->SetDims(dims.data(), dims.size()); + } + args_index++; } - // use the real input shape instead of nominal one to make sure - // dynamic shape is functional - auto dims = args[i]->toTensor().sizes(); - input_tensors[i]->SetDims(dims.data(), dims.size()); - input_tensor_structs.emplace_back(input_tensors[i]->CloneTensorStruct()); + input_tensor_structs.emplace_back(input_tensor->CloneTensorStruct()); } - int output_index = input_tensors.size(); for (const auto& output_tensor : output_tensors) { // pos=0 limits the search to the prefix - if (output_tensor->GetName().rfind("output_", 0) == 0) { - void* mutable_data_ptr = - args[output_index]->toTensor().mutable_data_ptr(); + if (output_tensor->GetName().rfind("output_", 0) == 0 && + output_tensor->GetName().find("mutbuf_") == std::string::npos) { + void* mutable_data_ptr = args[args_index]->toTensor().mutable_data_ptr(); if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) != Error::Ok) { output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */); } - output_index++; + args_index++; } output_tensor_structs.push_back(output_tensor->CloneTensorStruct()); } diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 0f64e8b9cce..0dd0470a2b0 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -35,6 +36,16 @@ bool CompareExportedInput( return numA < numB; } +int ExtractMutableBufferNumber(const std::string& name) { + std::string prefix = "mutbuf_"; + size_t startPos = name.find(prefix); + if (startPos != std::string::npos) { + startPos += prefix.length(); + return std::stoi(name.substr(startPos)); + } + return -1; +} + QnnManager::~QnnManager() { Destroy(); } @@ -363,9 +374,21 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) { std::vector output_tensors = backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(graph_name); + // Mapping memory address for the input and output of mutable buffer + std::unordered_map mutable_buffer_id_to_memory_map; + for (auto& tensor : input_tensors) { std::shared_ptr tensor_wrapper = CreateTensorWrapper(tensor); tensor_wrapper->UpdateQnnTensorMeta(tensor); + + int mutable_buffer_id = + ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1) { + // Delegate maintains the memory for mutable buffer + tensor_wrapper->AllocateDataBuffer(); + mutable_buffer_id_to_memory_map[mutable_buffer_id] = + tensor_wrapper->GetStaticTensorData(); + } input_tensors_[graph_name].emplace_back(std::move(tensor_wrapper)); } if (!options_->is_from_context_binary()) { @@ -388,6 +411,16 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) { if (IsTensorDump()) { tensor_wrapper->AllocateDataBuffer(); } + int mutable_buffer_id = + ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1 && + mutable_buffer_id_to_memory_map.find(mutable_buffer_id) != + mutable_buffer_id_to_memory_map.end()) { + // Fill the same memory for I/O of mutable buffer + tensor_wrapper->FillDataBuffer( + mutable_buffer_id_to_memory_map[mutable_buffer_id], + false /* copy_data */); + } output_tensors_[graph_name].emplace_back(std::move(tensor_wrapper)); } return Error::Ok; diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 091c2d94cd0..8be05d46688 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -909,17 +909,35 @@ def forward(self, x): return self.dispatcher[self.axis](x) +class IndexCopy(torch.nn.Module): + def __init__(self, skip_mutable_buffer=False): + super().__init__() + self.skip_mutable_buffer = skip_mutable_buffer + self.register_buffer( + "k_cache", + torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + persistent=True, + ) + + def forward(self, input_pos, k_val): + k_out = self.k_cache + k_out.index_copy_(1, input_pos, k_val) + return k_out + 0 + + class IndexPut(torch.nn.Module): - def __init__(self): + def __init__(self, skip_mutable_buffer=False): super().__init__() + self.skip_mutable_buffer = skip_mutable_buffer self.register_buffer( "k_cache", torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + persistent=True, ) def forward(self, input_pos, k_val): k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) - return k_out + return k_out + 0 class InstanceNorm2d(torch.nn.Module): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4f101db8e6e..747a6804957 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -618,13 +618,55 @@ def test_qnn_backend_index(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_index_copy(self): + test_comb = [ + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + self.lower_module_and_test_output( + test[QCOM_MODULE], + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) + def test_qnn_backend_index_put(self): - module = IndexPut() # noqa: F405 - sample_input = ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), - ) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + self.lower_module_and_test_output( + test[QCOM_MODULE], + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) def test_qnn_backend_instance_norm_2d(self): modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405 @@ -1860,14 +1902,61 @@ def test_qnn_backend_index(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_index_copy(self): + test_comb = [ + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output( + module, + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) + def test_qnn_backend_index_put(self): - module = IndexPut() # noqa: F405 - sample_input = ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), - ) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output( + module, + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) def test_qnn_backend_instance_norm_2d(self): modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405 @@ -3030,7 +3119,17 @@ def test_qnn_backend_generate_optrace(self): for _, (optrace, qhas) in binaries_trace.items(): with open(optrace, "r") as optrace_file: optrace_data = json.load(optrace_file) - for row in optrace_data: + # { + # header: + # { + # 'header_version': {'major': x, 'minor': y, 'patch': z}, + # 'version': {'major': x, 'minor': y, 'patch': z}, + # 'artifact_type': 'OP_TRACE' + # } + # traceEvents: + # {...} + # } + for row in optrace_data["traceEvents"]: self.assertIn("pid", row) with open(qhas, "r") as qhas_file: qhas_data = json.load(qhas_file) @@ -3726,7 +3825,17 @@ def test_qnn_backend_generate_optrace(self): for _, (optrace, qhas) in binaries_trace.items(): with open(optrace, "r") as optrace_file: optrace_data = json.load(optrace_file) - for row in optrace_data: + # { + # header: + # { + # 'header_version': {'major': x, 'minor': y, 'patch': z}, + # 'version': {'major': x, 'minor': y, 'patch': z}, + # 'artifact_type': 'OP_TRACE' + # } + # traceEvents: + # {...} + # } + for row in optrace_data["traceEvents"]: self.assertIn("pid", row) with open(qhas, "r") as qhas_file: qhas_data = json.load(qhas_file) @@ -3891,6 +4000,74 @@ def test_llama_stories_110m(self): class TestExampleOssScript(TestQNN): + def test_albert(self): + if not self.required_envs([self.sentence_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/albert.py", + "--dataset", + self.sentence_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["accuracy"], 0.8) + + def test_bert(self): + if not self.required_envs([self.sentence_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/bert.py", + "--dataset", + self.sentence_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["accuracy"], 0.6) + def test_conv_former(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") @@ -4033,6 +4210,40 @@ def test_dino_v2(self): self.assertGreaterEqual(msg["top_1"], 70) self.assertGreaterEqual(msg["top_5"], 85) + def test_distilbert(self): + if not self.required_envs([self.sentence_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/distilbert.py", + "--dataset", + self.sentence_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["accuracy"], 0.45) + def test_dit(self): if not self.required_envs(): self.skipTest("missing required envs") @@ -4142,14 +4353,13 @@ def test_efficientSAM(self): else: self.assertGreaterEqual(msg["MIoU"], 0.55) - def test_swin_transformer(self): - if not self.required_envs([self.image_dataset]): + def test_esrgan(self): + if not self.required_envs(): self.skipTest("missing required envs") + cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_transformer.py", - "--dataset", - self.image_dataset, + f"{self.executorch_root}/examples/qualcomm/oss_scripts/esrgan.py", "--artifact", self.artifact_dir, "--build_folder", @@ -4158,6 +4368,9 @@ def test_swin_transformer(self): self.device, "--model", self.model, + "--default_dataset", + "--oss_repo", + self.oss_repo, "--ip", self.ip, "--port", @@ -4174,16 +4387,17 @@ def test_swin_transformer(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 80) + self.assertGreaterEqual(msg["PSNR"], 24) + self.assertGreaterEqual(msg["SSIM"], 0.8) - def test_esrgan(self): - if not self.required_envs(): + def test_eurobert(self): + if not self.required_envs([self.sentence_dataset]): self.skipTest("missing required envs") - cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/oss_scripts/esrgan.py", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/eurobert.py", + "--dataset", + self.sentence_dataset, "--artifact", self.artifact_dir, "--build_folder", @@ -4192,9 +4406,6 @@ def test_esrgan(self): self.device, "--model", self.model, - "--default_dataset", - "--oss_repo", - self.oss_repo, "--ip", self.ip, "--port", @@ -4211,8 +4422,7 @@ def test_esrgan(self): if "Error" in msg: self.fail(msg["Error"]) else: - self.assertGreaterEqual(msg["PSNR"], 24) - self.assertGreaterEqual(msg["SSIM"], 0.8) + self.assertGreaterEqual(msg["accuracy"], 0.5) def test_fastvit(self): if not self.required_envs( @@ -4363,7 +4573,7 @@ def test_gMLP(self): self.fail(msg["Error"]) else: self.assertGreaterEqual(msg["top_1"], 60) - self.assertGreaterEqual(msg["top_5"], 90) + self.assertGreaterEqual(msg["top_5"], 85) @unittest.skip("Only outputs good accuracy in QNN 2.29") def test_mobilevit_v2(self): @@ -4654,6 +4864,41 @@ def test_ssd300_vgg16(self): else: self.assertGreaterEqual(msg["mAP"], 0.70) + def test_swin_transformer(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/swin_transformer.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) + class TestExampleQaihubScript(TestQNN): def test_utils_export(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 2968086d7a5..2e923b92250 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -462,6 +462,7 @@ def lower_module_and_test_output( passes_job: Optional[OrderedDict] = None, skip_node_id_set: set = None, skip_node_op_set: set = None, + skip_mutable_buffer: bool = False, dynamic_shapes: Dict = None, ): delegated_program = to_edge_transform_and_lower_to_qnn( @@ -472,6 +473,7 @@ def lower_module_and_test_output( passes_job=passes_job, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=skip_mutable_buffer, ) # this is needed for the ETRecord as lowering modifies the graph in-place diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 45f08c8f2c1..3471b0155bd 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -332,6 +332,7 @@ def to_edge_transform_and_lower_to_qnn( passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None, skip_node_id_set: Optional[set] = None, skip_node_op_set: Optional[set] = None, + skip_mutable_buffer: bool = False, ) -> EdgeProgramManager: """ Transforms and lowers a given PyTorch module to the QNN backend. @@ -356,6 +357,8 @@ def to_edge_transform_and_lower_to_qnn( Set of node IDs to skip during partitioning. skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning. + skip_mutable_buffer (Optional[set]): + Whether to skip delegating the mutable buffer in QNN backend. Returns: EdgeProgramManager: @@ -407,6 +410,7 @@ def ensure_graph_specific_dict(value, graph_names): compiler_specs[graph_name], skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=skip_mutable_buffer, ) ] for graph_name in graph_names diff --git a/backends/test/harness/TARGETS b/backends/test/harness/TARGETS new file mode 100644 index 00000000000..41d9a5b7682 --- /dev/null +++ b/backends/test/harness/TARGETS @@ -0,0 +1,18 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "tester", + srcs = [ + "__init__.py", + "tester.py", + ] + native.glob(["stages/*.py"]), + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/exir:graph_module", + ], +) diff --git a/backends/test/harness/__init__.py b/backends/test/harness/__init__.py new file mode 100644 index 00000000000..3660e979611 --- /dev/null +++ b/backends/test/harness/__init__.py @@ -0,0 +1,3 @@ +from .tester import Tester + +__all__ = ["Tester"] diff --git a/backends/test/harness/stages/__init__.py b/backends/test/harness/stages/__init__.py new file mode 100644 index 00000000000..36ed435ebd7 --- /dev/null +++ b/backends/test/harness/stages/__init__.py @@ -0,0 +1,22 @@ +from .export import Export +from .partition import Partition +from .quantize import Quantize +from .run_passes import RunPasses +from .serialize import Serialize +from .stage import Stage, StageType +from .to_edge import ToEdge +from .to_edge_transform_and_lower import ToEdgeTransformAndLower +from .to_executorch import ToExecutorch + +__all__ = [ + "Export", + "Partition", + "Quantize", + "RunPasses", + "Serialize", + "Stage", + "StageType", + "ToEdge", + "ToEdgeTransformAndLower", + "ToExecutorch", +] diff --git a/backends/test/harness/stages/export.py b/backends/test/harness/stages/export.py new file mode 100644 index 00000000000..53bef6bb083 --- /dev/null +++ b/backends/test/harness/stages/export.py @@ -0,0 +1,32 @@ +from typing import Any, Optional, Tuple + +import torch + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from torch.export import export, ExportedProgram + + +class Export(Stage): + def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None): + self.exported_program = None + self.dynamic_shapes = dynamic_shapes + + def stage_type(self) -> StageType: + return StageType.EXPORT + + def run( + self, + artifact: torch.nn.Module, + inputs: Tuple[torch.Tensor], + ) -> None: + self.exported_program = export( + artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True + ) + + @property + def artifact(self) -> ExportedProgram: + return self.exported_program + + @property + def graph_module(self) -> str: + return self.exported_program.graph_module diff --git a/backends/test/harness/stages/partition.py b/backends/test/harness/stages/partition.py new file mode 100644 index 00000000000..f1a2984fb5f --- /dev/null +++ b/backends/test/harness/stages/partition.py @@ -0,0 +1,26 @@ +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import EdgeProgramManager +from executorch.exir.backend.backend_api import validation_disabled +from executorch.exir.backend.partitioner import Partitioner + + +class Partition(Stage): + def __init__(self, partitioner: Partitioner): + self.partitioner = partitioner + self.delegate_module = None + + def stage_type(self) -> StageType: + return StageType.PARTITION + + def run(self, artifact: EdgeProgramManager, inputs=None): + with validation_disabled(): + self.delegate_module = artifact + self.delegate_module = self.delegate_module.to_backend(self.partitioner) + + @property + def artifact(self) -> EdgeProgramManager: + return self.delegate_module + + @property + def graph_module(self) -> str: + return self.delegate_module.exported_program().graph_module diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py new file mode 100644 index 00000000000..e03db058080 --- /dev/null +++ b/backends/test/harness/stages/quantize.py @@ -0,0 +1,79 @@ +from typing import Any, Optional, Sequence, Tuple + +import torch + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( + DuplicateDynamicQuantChainPass, +) + +from torch.export import export_for_training + +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer import Quantizer + + +class Quantize(Stage): + def __init__( + self, + quantizer: Optional[Quantizer] = None, + quantization_config: Optional[Any] = None, + calibrate: bool = True, + calibration_samples: Optional[Sequence[Any]] = None, + is_qat: Optional[bool] = False, + ): + self.quantizer = quantizer + self.quantization_config = quantization_config + self.calibrate = calibrate + self.calibration_samples = calibration_samples + + self.quantizer.set_global(self.quantization_config) + + self.converted_graph = None + self.is_qat = is_qat + + def stage_type(self) -> str: + return StageType.QUANTIZE + + def run( + self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] + ) -> None: + assert inputs is not None + if self.is_qat: + artifact.train() + captured_graph = export_for_training(artifact, inputs, strict=True).module() + + assert isinstance(captured_graph, torch.fx.GraphModule) + + if self.is_qat: + prepared = prepare_qat_pt2e(captured_graph, self.quantizer) + else: + prepared = prepare_pt2e(captured_graph, self.quantizer) + + if self.calibrate: + # Calibrate prepared model to provide data to quantization observers. + if self.calibration_samples is not None: + for inp in self.calibration_samples: + prepared(*inp) + else: + prepared(*inputs) + + converted = convert_pt2e(prepared) + DuplicateDynamicQuantChainPass()(converted) + + self.converted_graph = converted + + @property + def artifact(self) -> torch.fx.GraphModule: + return self.converted_graph + + @property + def graph_module(self) -> str: + return self.converted_graph + + def run_artifact(self, inputs): + return self.converted_graph.forward(*inputs) diff --git a/backends/test/harness/stages/run_passes.py b/backends/test/harness/stages/run_passes.py new file mode 100644 index 00000000000..b72c40d5337 --- /dev/null +++ b/backends/test/harness/stages/run_passes.py @@ -0,0 +1,66 @@ +from typing import Callable, List, Optional, Type, Union + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import EdgeProgramManager +from executorch.exir.program._program import _transform +from torch._export.pass_base import PassType +from torch.export import ExportedProgram + + +class RunPasses(Stage): + def __init__( + self, + pass_manager_cls: Type, + pass_list: Optional[List[Type[PassType]]] = None, + pass_functions: Optional[List[Callable]] = None, + ): + self.pass_manager_cls = pass_manager_cls + self.pass_list = pass_list + self.pass_functions = pass_functions + self.edge_or_aten_program = None + + def stage_type(self) -> StageType: + return StageType.RUN_PASSES + + def run( + self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None + ) -> None: + if isinstance(artifact, EdgeProgramManager): + self.edge_or_aten_program = artifact + if self.pass_list: + pass_manager = self.pass_manager_cls( + artifact.exported_program(), self.pass_list + ) + self.edge_or_aten_program._edge_programs["forward"] = ( + pass_manager.transform() + ) + if self.pass_functions: + assert isinstance(self.pass_functions, list) + for pass_function in self.pass_functions: + self.edge_or_aten_program._edge_programs["forward"] = pass_function( + self.edge_or_aten_program.exported_program() + ) + else: + transformed_ep = artifact + if self.pass_list: + assert isinstance(self.pass_list, list) + for pass_ in self.pass_list: + transformed_ep = _transform(transformed_ep, pass_()) + + if self.pass_functions: + assert isinstance(self.pass_functions, list) + for pass_function in self.pass_functions: + transformed_ep = pass_function(transformed_ep) + + self.edge_or_aten_program = transformed_ep + + @property + def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]: + return self.edge_or_aten_program + + @property + def graph_module(self) -> str: + if isinstance(self.edge_or_aten_program, EdgeProgramManager): + return self.edge_or_aten_program.exported_program().graph_module + else: + return self.edge_or_aten_program.graph_module diff --git a/backends/test/harness/stages/serialize.py b/backends/test/harness/stages/serialize.py new file mode 100644 index 00000000000..9d0bded0483 --- /dev/null +++ b/backends/test/harness/stages/serialize.py @@ -0,0 +1,56 @@ +import copy +import logging + +from typing import Optional + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import ExecutorchProgramManager + +from torch.utils._pytree import tree_flatten + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +try: + from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, + ) +except ImportError as e: + logger.warning(f"{e=}") + pass + + +class Serialize(Stage): + def __init__(self): + self.buffer = None + + def stage_type(self) -> StageType: + return StageType.SERIALIZE + + def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: + self.buffer = artifact.buffer + + @property + def artifact(self) -> bytes: + return self.buffer + + @property + def graph_module(self) -> None: + return None + + def run_artifact(self, inputs): + inputs_flattened, _ = tree_flatten(inputs) + executorch_module = _load_for_executorch_from_buffer(self.buffer) + executorch_output = copy.deepcopy( + executorch_module.run_method("forward", tuple(inputs_flattened)) + ) + return executorch_output + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + dump_artifact is overridden to dump the serialized bytes into pte file + """ + if not path_to_dump: + raise RuntimeError("path_to_dump file not provided") + else: + with open(path_to_dump, "wb") as f: + f.write(self.artifact) diff --git a/backends/test/harness/stages/stage.py b/backends/test/harness/stages/stage.py new file mode 100644 index 00000000000..f1f2604b766 --- /dev/null +++ b/backends/test/harness/stages/stage.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional + +from executorch.exir import EdgeProgramManager + +from torch.export import ExportedProgram + + +class StageType(Enum): + QUANTIZE = 0 + EXPORT = 1 + RUN_PASSES = 2 + TO_EDGE = 3 + TO_EDGE_TRANSFORM_AND_LOWER = 4 + PARTITION = 5 + TO_EXECUTORCH = 6 + SERIALIZE = 7 + INITIAL_MODEL = 8 + + +class Stage(ABC): + """ + Interface for a Stage in the PT2.0 lowering pipeline + """ + + @abstractmethod + def stage_type(self) -> StageType: + """ + Returns the type of the stage. + """ + pass + + @abstractmethod + def run(self, artifact, inputs): + """ + Executes this stage, generates the 'artifact', for later stages. + """ + pass + + @property + @abstractmethod + def artifact(self): + """ + Returns the artifact generated by this stage. To be used by the next stage in the pipeline. + """ + pass + + @property + @abstractmethod + def graph_module(self): + """ + Return the artifact's graph module for this stage + """ + pass + + def run_artifact(self, inputs): + """ + Returns the output of calling the artifact generated by this stage with inputs + """ + if isinstance(self.artifact, ExportedProgram): + return self.artifact(*inputs) + else: + return self.artifact.exported_program().module()(*inputs) + + # Debug Tools for stages + def artifact_str(self): + """ + Return string printable artifact for this stage + """ + if isinstance(self.artifact, EdgeProgramManager): + return self.artifact.exported_program() + return self.artifact + + def stage_banner(self): + """ + Returns banner string for this stage + """ + return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n" + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + Dumps string printable artifact to path. If path_to_dump, then it is printed to terminal + """ + if path_to_dump: + with open(path_to_dump, "a") as fp: + fp.write(str(self.stage_banner() + "\n")) + fp.write(str(self.artifact_str())) + else: + print(self.stage_banner() + "\n") + print(self.artifact_str()) diff --git a/backends/test/harness/stages/to_edge.py b/backends/test/harness/stages/to_edge.py new file mode 100644 index 00000000000..460b6371c9e --- /dev/null +++ b/backends/test/harness/stages/to_edge.py @@ -0,0 +1,27 @@ +from typing import Optional + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge +from torch.export import ExportedProgram + + +class ToEdge(Stage): + def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None): + self.edge_compile_conf = edge_compile_config or EdgeCompileConfig() + self.edge_dialect_program = None + + def stage_type(self) -> StageType: + return StageType.TO_EDGE + + def run(self, artifact: ExportedProgram, inputs=None) -> None: + self.edge_dialect_program = to_edge( + artifact, compile_config=self.edge_compile_conf + ) + + @property + def artifact(self) -> EdgeProgramManager: + return self.edge_dialect_program + + @property + def graph_module(self) -> str: + return self.edge_dialect_program.exported_program().graph_module diff --git a/backends/test/harness/stages/to_edge_transform_and_lower.py b/backends/test/harness/stages/to_edge_transform_and_lower.py new file mode 100644 index 00000000000..6c5aa4b541b --- /dev/null +++ b/backends/test/harness/stages/to_edge_transform_and_lower.py @@ -0,0 +1,40 @@ +from typing import List, Optional, Type + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + to_edge_transform_and_lower, +) +from executorch.exir.backend.partitioner import Partitioner +from torch.export import ExportedProgram + + +class ToEdgeTransformAndLower(Stage): + def __init__( + self, + default_partitioner_cls: Type, + partitioners: Optional[List[Partitioner]] = None, + edge_compile_config: Optional[EdgeCompileConfig] = None, + ): + self.partitioners = partitioners or [default_partitioner_cls()] + self.edge_compile_conf = edge_compile_config or EdgeCompileConfig() + self.edge_dialect_program = None + + def stage_type(self) -> StageType: + return StageType.TO_EDGE_TRANSFORM_AND_LOWER + + def run(self, artifact: ExportedProgram, inputs=None) -> None: + self.edge_dialect_program = to_edge_transform_and_lower( + artifact, + compile_config=self.edge_compile_conf, + partitioner=self.partitioners, + ) + + @property + def artifact(self) -> EdgeProgramManager: + return self.edge_dialect_program + + @property + def graph_module(self) -> str: + return self.edge_dialect_program.exported_program().graph_module diff --git a/backends/test/harness/stages/to_executorch.py b/backends/test/harness/stages/to_executorch.py new file mode 100644 index 00000000000..d3154e6bc2d --- /dev/null +++ b/backends/test/harness/stages/to_executorch.py @@ -0,0 +1,54 @@ +import sys + +from typing import Optional + +from executorch.backends.test.harness.stages.stage import Stage, StageType +from executorch.exir import ( + EdgeProgramManager, + ExecutorchBackendConfig, + ExecutorchProgramManager, +) +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.exir.print_program import pretty_print, print_program + + +class ToExecutorch(Stage): + def __init__( + self, + config: Optional[ExecutorchBackendConfig] = None, + ): + self.config = config or ExecutorchBackendConfig( + extract_delegate_segments=True, + sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + ) + self.executorch_program = None + + def stage_type(self) -> StageType: + return StageType.TO_EXECUTORCH + + def run(self, artifact: EdgeProgramManager, inputs=None): + self.executorch_program = artifact.to_executorch(self.config) + + @property + def artifact(self) -> ExecutorchProgramManager: + return self.executorch_program + + @property + def graph_module(self) -> str: + return self.executorch_program().graph_module + + def dump_artifact(self, path_to_dump: Optional[str]): + """ + dump_artifact is overridden to dump the serialized program + """ + original_stdout = sys.stdout + + sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout + print(self.stage_banner() + "\n") + pretty_print(self.artifact._emitter_output.program) + print_program( + self.artifact._emitter_output.program, + show_meminfo=True, + mark_dynamic_shape_tensor=True, + ) + sys.stdout = original_stdout diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py new file mode 100644 index 00000000000..f1dfeb23531 --- /dev/null +++ b/backends/test/harness/tester.py @@ -0,0 +1,453 @@ +import random +from collections import Counter, OrderedDict +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from executorch.backends.test.harness.stages import ( + Export, + Partition, + Quantize, + RunPasses, + Serialize, + Stage, + StageType, + ToEdge, + ToEdgeTransformAndLower, + ToExecutorch, +) +from executorch.exir.dim_order_utils import get_memory_format + +from torch.export import ExportedProgram +from torch.testing import FileCheck + + +class Tester: + """ + Base class for a backend tester. This class is not intended to be used directly. Instead, + backends are expected to subclass it and provide implementations for backend-dependent + stages. + """ + + def __init__( + self, + module: torch.nn.Module, + example_inputs: Tuple[torch.Tensor], + stage_classes: Dict[StageType, Type], + dynamic_shapes: Optional[Tuple[Any]] = None, + ): + module.eval() + + self.stage_classes = stage_classes + self.original_module = module + self.example_inputs = example_inputs + self.dynamic_shapes = dynamic_shapes + self.stages: Dict[StageType, Stage] = OrderedDict.fromkeys(list(StageType)) + self.pipeline = { + StageType.QUANTIZE: [StageType.EXPORT], + StageType.EXPORT: [ + StageType.RUN_PASSES, + StageType.TO_EDGE, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + ], + StageType.TO_EDGE_TRANSFORM_AND_LOWER: [ + StageType.RUN_PASSES, + StageType.TO_EXECUTORCH, + ], + StageType.TO_EDGE: [ + StageType.PARTITION, + StageType.RUN_PASSES, + ], + StageType.RUN_PASSES: [ + StageType.PARTITION, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + ], + # TODO Make this Stage optional + StageType.PARTITION: [StageType.TO_EXECUTORCH], + StageType.TO_EXECUTORCH: [StageType.SERIALIZE], + StageType.SERIALIZE: [], + } + + # Current stage type + self.cur: Optional[StageType] = None + + # Reference output from eager mode + self.reference_output = None + + # Quantization scale from eager mode + self.quantization_scale: Optional[float] = None + + # Artifact output from stage + self.stage_output = None + + @staticmethod + def default_stage_classes() -> Dict[StageType, Type]: + """ + Returns a map of StageType to default Stage implementation. + """ + return { + StageType.EXPORT: Export, + StageType.QUANTIZE: Quantize, + StageType.PARTITION: Partition, + StageType.RUN_PASSES: RunPasses, + StageType.SERIALIZE: Serialize, + StageType.TO_EDGE: ToEdge, + StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower, + StageType.TO_EXECUTORCH: ToExecutorch, + } + + def _get_default_stage(self, stage_type: StageType, *args, **kwargs) -> Stage: + stage_class = self.stage_classes.get(stage_type) + if stage_class is None: + raise RuntimeError( + f"Attempted to instantiate a default implementation for stage {stage_type} but no default class was registered." + ) + return stage_class(*args, **kwargs) + + def generate_random_inputs(self): + # Get shapes of inputs + input_shapes = [] + if self.dynamic_shapes is None: + for tensor_arg in self.example_inputs: + assert isinstance(tensor_arg, torch.Tensor) + input_shapes.append(tensor_arg.shape) + else: + # Random shapes depending on dynamic shape constraint + dim_name_to_size = {} + for arg_idx in range(len(self.example_inputs)): + assert isinstance(self.example_inputs[arg_idx], torch.Tensor) + ex_shape = list(self.example_inputs[arg_idx].shape) + dynamic_dim_spec = self.dynamic_shapes[arg_idx] + for dim_idx, dim_spec in dynamic_dim_spec.items(): + assert dim_idx < len(ex_shape) + if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim): + # derived dims are of the form {0: 2 * torch.export.Dim() // 2} + # The root contains the min/max of the export dim and fn contains + # the function to compute the derived dim. + dim_spec = dim_spec.root + fn = dim_spec.fn + elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim): + # Not derived dim so fn is just itself + def fn(x): + return x + + else: + raise RuntimeError( + f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}" + ) + dim_name = dim_spec.__name__ + if dim_name not in dim_name_to_size: + upper_bound = min( + dim_spec.max, 1000 + ) # unbounded int max is too large + lower_bound = ( + dim_spec.min if dim_spec.min >= 2 else 1 + ) # 0/1 specialization means dim_spec.min can never be 1 + dim_name_to_size[dim_name] = fn( + random.randint(lower_bound, upper_bound) + ) + ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__] + input_shapes.append(torch.Size(ex_shape)) + # create random tensor inputs with the shapes given above: + random_inputs = [] + for arg_idx in range(len(self.example_inputs)): + memFormat = get_memory_format( + list(self.example_inputs[arg_idx].dim_order()) + ) + random_inputs.append( + torch.randn(input_shapes[arg_idx]) + .to(dtype=self.example_inputs[arg_idx].dtype) + .to(memory_format=memFormat) + ) + + yield tuple(random_inputs) + + def _pre(self, stage): + stage_type = stage.stage_type() + assert stage_type in self.stages and not self.stages[stage_type] + + last_artifact = self.original_module + if self.cur: + assert self.cur in self.pipeline, f"Invalid state: {self.cur}" + allowed_next_stages = self.pipeline[self.cur] + assert ( + stage_type in allowed_next_stages + ), f"Invalid next stage: {stage_type}" + last_artifact = self.get_artifact() + self.cur = stage_type + return last_artifact + + def _post(self, stage): + stage_type = stage.stage_type() + assert stage_type in self.stages + self.stages[stage_type] = stage + + def _run_stage(self, stage_instance, inputs=None): + assert isinstance(stage_instance, Stage) + prev_stage_artifact = self._pre(stage_instance) + stage_instance.run(prev_stage_artifact, inputs=inputs) + self._post(stage_instance) + return self + + # Stages + def quantize(self, quantize_stage: Optional[Quantize] = None): + return self._run_stage( + quantize_stage or self._get_default_stage(StageType.QUANTIZE), + self.example_inputs, + ) + + def export(self, export_stage: Optional[Export] = None): + return self._run_stage( + export_stage + or self._get_default_stage( + StageType.EXPORT, dynamic_shapes=self.dynamic_shapes + ), + self.example_inputs, + ) + + def to_edge(self, to_edge_stage: Optional[ToEdge] = None): + if not to_edge_stage: + to_edge_stage = self._get_default_stage(StageType.TO_EDGE) + res = self._run_stage(to_edge_stage) + return res + + def to_edge_transform_and_lower( + self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None + ): + return self._run_stage( + to_edge_and_transform_stage + or self._get_default_stage(StageType.TO_EDGE_TRANSFORM_AND_LOWER) + ) + + def run_passes(self, run_passes_stage: Optional[RunPasses] = None): + return self._run_stage( + run_passes_stage or self._get_default_stage(StageType.RUN_PASSES) + ) + + def partition(self, partition_stage: Optional[Partition] = None): + return self._run_stage( + partition_stage or self._get_default_stage(StageType.PARTITION) + ) + + def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None): + return self._run_stage( + to_executorch_stage or self._get_default_stage(StageType.TO_EXECUTORCH) + ) + + def serialize(self, serialize_stage: Optional[Serialize] = None): + return self._run_stage( + serialize_stage or self._get_default_stage(StageType.SERIALIZE) + ) + + # Util functions + def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None): + stage = stage or self.cur + self.stages[stage].dump_artifact(path) + return self + + def get_artifact(self, stage: Optional[StageType] = None): + stage = stage or self.cur + return self.stages[stage].artifact + + def check(self, input: List[str]): + for key in input: + FileCheck().check(key).run(self.stages[self.cur].graph_module.code) + return self + + def check_not(self, input: List[str]): + for key in input: + FileCheck().check_not(key).run(self.stages[self.cur].graph_module.code) + return self + + def check_count(self, input: Dict[Any, int]): + # TODO target checks similar to checkGraphModuleNodes() + for key, count in input.items(): + FileCheck().check_count(key, count, exactly=True).run( + self.stages[self.cur].graph_module.code + ) + return self + + def check_node_count(self, input: Dict[Any, int]): + # Count the occurances of each target in the graph. + target_ops = [ + node.target + for node in self.stages[self.cur].graph_module.graph.nodes + if node.op == "call_function" + ] + op_counts = Counter(target_ops) + + for key, count in input.items(): + if count != op_counts[key]: + print(f"Nodes: {op_counts}") + raise AssertionError( + f"Expected {count} {key} nodes but found {op_counts[key]}." + ) + + return self + + def visualize( + self, reuse_server: bool = True, stage: Optional[StageType] = None, **kwargs + ): + # import here to avoid importing model_explorer when it is not needed which is most of the time. + from executorch.devtools.visualization import visualize + + visualize(self.get_artifact(stage), reuse_server=reuse_server, **kwargs) + return self + + def run_method_and_compare_outputs( + self, + stage: Optional[StageType] = None, + inputs: Optional[Tuple[torch.Tensor]] = None, + num_runs=1, + atol=1e-03, + rtol=1e-03, + qtol=0, + ): + number_of_runs = 1 if inputs is not None else num_runs + reference_stage = self.stages[StageType.EXPORT] + + stage = stage or self.cur + + print(f"Comparing Stage {stage} with Stage {reference_stage}") + for run_iteration in range(number_of_runs): + inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) + input_shapes = [generated_input.shape for generated_input in inputs_to_run] + print(f"Run {run_iteration} with input shapes: {input_shapes}") + + # Reference output (and quantization scale) + ( + reference_output, + quantization_scale, + ) = self._calculate_reference_output( + reference_stage.artifact, inputs_to_run + ) + + # Output from running artifact at stage + stage_output = self.stages[stage].run_artifact(inputs_to_run) + self._compare_outputs( + reference_output, stage_output, quantization_scale, atol, rtol, qtol + ) + + return self + + @staticmethod + def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): + """ + Helper testing function that asserts that the model output and the reference output + are equal with some tolerance. Due to numerical differences between eager mode and + the XNNPACK's backend, we relax the detal such that absolute tolerance is 1e-3. and + relative tolerance is 1e-3. In the event that the computation was quantized, we + further relax the tolerance to one quantized step (equal to the quantization scale). + This allows the quantized value to differ by 1 between the reference and model output. + """ + + assert len(model_output) == len(ref_output) + + for i in range(len(model_output)): + model = model_output[i] + ref = ref_output[i] + assert ( + ref.shape == model.shape + ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" + if model.dtype == torch.bool: + assert torch.equal(model, ref), ( + f"Output {i} (bool tensor) does not match reference output.\n" + f"\tShape: {model.shape}\n" + f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" + ) + else: + assert torch.allclose( + model, + ref, + atol=atol, + rtol=rtol, + ), ( + f"Output {i} does not match reference output.\n" + f"\tGiven atol: {atol}, rtol: {rtol}.\n" + f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" + f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" + f"\t-- Model vs. Reference --\n" + f"\t Numel: {model.numel()}, {ref.numel()}\n" + f"\tMedian: {model.median()}, {ref.median()}\n" + f"\t Mean: {model.mean()}, {ref.mean()}\n" + f"\t Max: {model.max()}, {ref.max()}\n" + f"\t Min: {model.min()}, {ref.min()}\n" + ) + + @staticmethod + def _compare_outputs( + reference_output, + stage_output, + quantization_scale=None, + atol=1e-03, + rtol=1e-03, + qtol=0, + ): + """ + Compares the original of the original nn module with the output of the generated artifact. + This requres calling run_method before calling compare_outputs. As that runs the generated + artifact on the sample inputs and sets the stage output to be compared against the reference. + """ + # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor + if isinstance(reference_output, torch.Tensor): + reference_output = (reference_output,) + if isinstance(stage_output, torch.Tensor): + stage_output = (stage_output,) + + # If a qtol is provided and we found an dequantization node prior to the output, relax the + # atol by qtol quant units. + if quantization_scale is not None: + atol += quantization_scale * qtol + + Tester._assert_outputs_equal( + stage_output, + reference_output, + atol=atol, + rtol=rtol, + ) + + @staticmethod + def _calculate_reference_output( + program: ExportedProgram, inputs + ) -> Tuple[torch.Tensor, Optional[float]]: + """ + Execute the reference program and return the output. If the output comes from a dequantize node, + return the quantization scale as well. + """ + + # Locate the output node. + output_node = None + for node in program.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None + + # Look for a dequantization node in the output node args. Returned values are found in the first + # argument of the output node. + dequant_node = None + for arg_node in output_node.args[0]: + if ( + arg_node.op == "call_function" + and arg_node.target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + dequant_node = arg_node + break + + scale = None + if dequant_node is not None: + original_target = dequant_node.target + + # Replace the dequant node with shim to intercept the quantization parameters. + # It will be invoked when we evaluate the program to find the reference outputs. + def dequant_shim(*args): + nonlocal scale + scale = args[1] + result = original_target(*args) + return result + + dequant_node.target = dequant_shim + + output = program.module()(*inputs) + return output, scale diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..dc8275bc099 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: @@ -775,85 +845,88 @@ def generateSPV( # noqa: C901 ) -> Dict[str, str]: output_file_map = {} - def process_shader(shader_paths_pair): + def generate_src_file(shader_paths_pair): + # Extract components from the input tuple + # name of .glsl, .glslh, or .h to be generated src_file_name = shader_paths_pair[0] - + # path of template file used for codegen src_file_fullpath = shader_paths_pair[1][0] + # args to be used for codegen codegen_params = shader_paths_pair[1][1] - requires_codegen = True - if "YAML_SRC_FULLPATH" not in codegen_params: - requires_codegen = False - + # Assume that generated files will have the same file extension as the + # source template file. src_file_ext = extract_extension(src_file_fullpath) out_file_ext = src_file_ext - compile_spv = False - if out_file_ext == "glsl": - compile_spv = True + # Construct generated file name + gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}") + # Construct path of cached generated file + cached_gen_out_path = os.path.join( + cache_dir, f"{src_file_name}.{out_file_ext}" + ) + + # Execute codegen to generate the output file + with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file: + input_text = input_file.read() + input_text = self.maybe_replace_u16vecn(input_text) + output_text = preprocess(input_text, codegen_params) + + with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file: + output_file.write(output_text) + + if cache_dir is not None: + # Store the generated file in the cache for SPIR-V compilation + shutil.copyfile(gen_out_path, cached_gen_out_path) + + def compile_spirv(shader_paths_pair): + # Extract components from the input tuple + # name of generated .glsl, .glslh, or .h + src_file_name = shader_paths_pair[0] + # path of template file used for codegen + src_file_fullpath = shader_paths_pair[1][0] + # Assume that generated files will have the same file extension as the + # source template file. + src_file_ext = extract_extension(src_file_fullpath) + out_file_ext = src_file_ext + + # Infer name of generated file (created by generate_src_file) gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}") - spv_out_path = None - if compile_spv: - spv_out_path = os.path.join(output_dir, f"{src_file_name}.spv") + + # Only proceed if GLSL -> SPIR-V compilation is required for this file + if out_file_ext != "glsl": + return (None, gen_out_path) + + # Construct name of SPIR-V file to be compiled, if needed + spv_out_path = os.path.join(output_dir, f"{src_file_name}.spv") if cache_dir is not None: - cached_src_file_fullpath = os.path.join( - cache_dir, os.path.basename(src_file_fullpath) + ".t" - ) - cached_codegen_yaml = os.path.join(cache_dir, f"{src_file_name}.yaml") + # Construct the file names of cached SPIR-V file to check if they exist + # in the cache. cached_gen_out_path = os.path.join( cache_dir, f"{src_file_name}.{out_file_ext}" ) cached_spv_out_path = os.path.join(cache_dir, f"{src_file_name}.spv") + + # Only use cached artifacts if all of the expected artifacts are present if ( not force_rebuild - and os.path.exists(cached_src_file_fullpath) and os.path.exists(cached_gen_out_path) - and (not requires_codegen or os.path.exists(cached_codegen_yaml)) - and (not compile_spv or os.path.exists(cached_spv_out_path)) + and os.path.exists(cached_spv_out_path) ): - current_checksum = self.get_md5_checksum(src_file_fullpath) - cached_checksum = self.get_md5_checksum(cached_src_file_fullpath) - yaml_unchanged = True - if requires_codegen: - yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"] - current_yaml_checksum = self.get_md5_checksum( - yaml_file_fullpath - ) - cached_yaml_checksum = self.get_md5_checksum( - cached_codegen_yaml - ) - yaml_unchanged = current_yaml_checksum == cached_yaml_checksum - # If the cached source GLSL template is the same as the current GLSL - # source file, then assume that the generated GLSL and SPIR-V will - # not have changed. In that case, just copy over the GLSL and SPIR-V - # files from the cache. - if yaml_unchanged and current_checksum == cached_checksum: - shutil.copyfile(cached_gen_out_path, gen_out_path) - if compile_spv: - shutil.copyfile(cached_spv_out_path, spv_out_path) + current_checksum = self.get_md5_checksum(gen_out_path) + cached_checksum = self.get_md5_checksum(cached_gen_out_path) + # If the cached generated GLSL file is the same as the current GLSL + # generated file, then assume that the generated GLSL and SPIR-V + # will not have changed. In that case, just copy over the GLSL and + # SPIR-V files from the cache and return. + if current_checksum == cached_checksum: + shutil.copyfile(cached_spv_out_path, spv_out_path) return (spv_out_path, gen_out_path) - with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file: - input_text = input_file.read() - input_text = self.maybe_replace_u16vecn(input_text) - output_text = preprocess(input_text, codegen_params) - - with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file: - output_file.write(output_text) - - if cache_dir is not None: - # Otherwise, store the generated GLSL files in the cache - shutil.copyfile(gen_out_path, cached_gen_out_path) - # If a YAML file was used to configure codegen, cache it as well - if requires_codegen: - yaml_file_fullpath = codegen_params["YAML_SRC_FULLPATH"] - shutil.copyfile(yaml_file_fullpath, cached_codegen_yaml) - - # If no GLSL compiler is specified, or the source file is not a GLSL shader - # then only write out the generated GLSL shaders. - if compile_spv and self.glslc_path is not None: + # Only proceed if a GLSL compiler was specified + if self.glslc_path is not None: cmd_base = [ self.glslc_path, "-fshader-stage=compute", @@ -891,10 +964,15 @@ def process_shader(shader_paths_pair): return (spv_out_path, gen_out_path) - # Parallelize shader compilation as much as possible to optimize build time. + # Run codegen serially to ensure that all .glsl, .glslh, and .h files are up to + # date before compilation + for generated_file_tuple in self.output_file_map.items(): + generate_src_file(generated_file_tuple) + + # Parallelize SPIR-V compilation to optimize build time with ThreadPool(os.cpu_count()) as pool: for spv_out_path, glsl_out_path in pool.map( - process_shader, self.output_file_map.items() + compile_spirv, self.output_file_map.items() ): output_file_map[spv_out_path] = glsl_out_path diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh new file mode 100644 index 00000000000..66620e9b174 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CHOOSE_QPARAMS_GLSLH +#define CHOOSE_QPARAMS_GLSLH + +// equivalent of the eps defined in the cpu implementation +#define SMALL_SCALE_THRESHOLD 6.1e-5 + +// Calculate scale and zero point from min and max values +void calculate_scale_and_zero_point( + float min_val, + float max_val, + int qmin, + int qmax, + out float scale_val, + out int zero_point_val) { + // ensure we have zero included in our range + min_val = min(min_val, 0.0); + max_val = max(max_val, 0.0); + + scale_val = (max_val - min_val) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale_val == 0.0 || isinf(1.0 / scale_val)) { + scale_val = 0.1; + } + + // Cut off small scale + if (scale_val < SMALL_SCALE_THRESHOLD) { + float org_scale = scale_val; + scale_val = SMALL_SCALE_THRESHOLD; + + // Adjust min and max based on new scale + if (min_val == 0.0) { + max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else if (max_val == 0.0) { + min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + float zero_point_from_min = float(qmin) - min_val / scale_val; + float zero_point_from_max = float(qmax) - max_val / scale_val; + float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); + float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); + float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to integer + if (initial_zero_point < float(qmin)) { + zero_point_val = qmin; + } else if (initial_zero_point > float(qmax)) { + zero_point_val = qmax; + } else { + zero_point_val = int(round(initial_zero_point)); + } +} + +#endif // CHOOSE_QPARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl new file mode 100644 index 00000000000..dcbfe493f34 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -0,0 +1,278 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} +${layout_declare_ubo(B, "ivec4", "t_scale_strides")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} +${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * - Per-Token Mode: + * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) + * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses simple linear indexing through buffer elements + * - No axis mapping or packing considerations - processes elements sequentially + * - Works with any tensor layout since it accesses buffer data linearly + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor with strided access + * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] + * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing one token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup finds min/max within its assigned token + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + + // Each thread processes multiple elements with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + for (uint i = global_id; i < total_elements; i += total_threads) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only) + if (local_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[0] = scale_val; + t_zero_point[0] = zero_point_val; + } +} + +#else + +void choose_qparams_per_token() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); + uint token_size = total_elements / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + // This handles the case where we have more tokens than workgroups + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Early exit if this workgroup has no tokens to process + if (start_token >= uint(num_tokens)) { + return; + } + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the start and end indices for this token + uint token_start = token_id * token_size; + uint token_end = token_start + token_size; + + // Each thread processes multiple elements within the token with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process elements within this token only + for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + float val = t_in[i]; + if (!isnan(val) && !isinf(val)) { + if (!found_valid) { + thread_min = val; + thread_max = val; + found_valid = true; + } else { + thread_min = min(thread_min, val); + thread_max = max(thread_max, val); + } + } + } + + // Intra-group reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml new file mode 100644 index 00000000000..c37039f68e9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -0,0 +1,12 @@ +choose_qparams_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor_buffer + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl new file mode 100644 index 00000000000..282f1de170a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -0,0 +1,398 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} +${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + int quant_min; + int quant_max; + }; +$else: + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_scale_limits")} +${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} + +#include "indexing_utils.h" +#include "choose_qparams.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#define NWORKERS 64 + +// Shared memory for reduction - must match local work group size +shared float shared_min[NWORKERS]; +shared float shared_max[NWORKERS]; + +/* + * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) + * + * This shader computes quantization parameters (scale and zero_point) for converting + * floating-point tensors to n-bit integer representations while preserving the + * original data range as much as possible. + * + * ALGORITHM: + * 1. Find global min/max values across tensor elements using parallel reduction + * 2. Use tree reduction with shared memory for efficient min/max computation + * 3. Calculate scale = (max - min) / (quant_max - quant_min) + * 4. Calculate zero_point to map floating-point zero to integer value + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: Default (typically {num_elements, 1, 1}) + * - Local WG Size: Default (typically {64, 1, 1}) + * - Per-Token Mode: + * - Global WG Size: Default (typically based on tensor dimensions) + * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with linear texel iteration + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - Note: Axis mapping support depends on indexing utilities + * + * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: + * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: + * + * Initial shared_min/shared_max arrays populated by each thread: + * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * + * Stride 1 (compare pairs, keep min/max): + * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + * Active: | 0 | | 2 | | 4 | | 6 | | + * + * Stride 2 (compare pairs, keep min/max): + * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) + * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + * Active: | 0 | | | | 4 | | | | + * + * Stride 4 (final comparison): + * shared_min: | 0 | | | | | | | | (min(0,0) = 0) + * shared_max: | 10 | | | | | | | | (max(10,5) = 10) + * Active: | 0 | | | | | | | | + * + * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + * + * PER-TENSOR QUANTIZATION: + * - Single workgroup processes entire tensor + * - Each thread processes multiple texels with stride + * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] + * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] + * - Tree reduction combines all thread results into global min/max + * - Output: Single scale and zero_point values + * + * PER-TOKEN QUANTIZATION: + * - Multiple workgroups, each processing subset of tokens + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Each workgroup processes multiple tokens if num_tokens > num_workgroups + * - Within each token, threads process texels containing token elements + * - Output: Array of scale and zero_point values (one per token) + */ + +#ifdef per_tensor + +void choose_qparams_per_tensor() { + uint global_id = gl_GlobalInvocationID.x; + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Each thread processes multiple texels with stride + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels with stride across all threads + for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final result calculation (single workgroup only for reliability) + if (local_id == 0 && group_id == 0) { + float global_min = shared_min[0]; + float global_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val); + + write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); + } +} + +#else + +void choose_qparams_per_token() { + // Each token is processed by multiple workgroups for parallel reduction + uint local_id = gl_LocalInvocationID.x; + uint group_id = gl_WorkGroupID.x; + uint total_workgroups = gl_NumWorkGroups.x; + + uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); + + // Calculate texels per token (assuming last dimension contains the token data) + // For per-token quantization, we assume tokens are along the last dimension + uint texels_per_token = total_texels / uint(num_tokens); + + // Calculate how many tokens each workgroup should process + uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; + + // Calculate which tokens this workgroup is responsible for + uint start_token = group_id * tokens_per_workgroup; + uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + + // Process each token assigned to this workgroup + for (uint token_id = start_token; token_id < end_token; token_id++) { + // Calculate the texel range for this token + uint token_start_texel = token_id * texels_per_token; + uint token_end_texel = token_start_texel + texels_per_token; + + // Each thread processes multiple texels within the token + float thread_min = 1.0/0.0; // +infinity + float thread_max = -1.0/0.0; // -infinity + bool found_valid = false; + + // Process texels within this token only + for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { + // Convert linear texel index to 3D coordinates + uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); + uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); + uint y = remainder / uint(t_in_limits.x); + uint x = remainder % uint(t_in_limits.x); + ivec3 texel_pos = ivec3(int(x), int(y), int(z)); + + FVEC4_T texel_data = load_texel(t_in, texel_pos); + + // For texture storage, we assume width-packed (packed_dim = 0) + // Calculate number of valid elements in this texel (handle padding) + int packed_dim = 0; // Width dimension is packed + ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format + ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); + + // Calculate total tensor elements to determine padding + int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; + int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + + tensor_coord.z * sizes.x * sizes.y; + int remaining_elements = total_elements - (linear_tensor_idx); + int valid_elements = min(4, remaining_elements); + + // Find min/max within this texel, considering only valid elements + if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { + if (!found_valid) { + thread_min = texel_data.x; + thread_max = texel_data.x; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.x); + thread_max = max(thread_max, texel_data.x); + } + } + + if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { + if (!found_valid) { + thread_min = texel_data.y; + thread_max = texel_data.y; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.y); + thread_max = max(thread_max, texel_data.y); + } + } + + if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { + if (!found_valid) { + thread_min = texel_data.z; + thread_max = texel_data.z; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.z); + thread_max = max(thread_max, texel_data.z); + } + } + + if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { + if (!found_valid) { + thread_min = texel_data.w; + thread_max = texel_data.w; + found_valid = true; + } else { + thread_min = min(thread_min, texel_data.w); + thread_max = max(thread_max, texel_data.w); + } + } + } + + // Intra-workgroup reduction using shared memory + shared_min[local_id] = thread_min; + shared_max[local_id] = thread_max; + barrier(); + + // Tree reduction within work group + for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + float other_min = shared_min[local_id + stride]; + float other_max = shared_max[local_id + stride]; + + // Handle infinity values properly + if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { + shared_min[local_id] = other_min; + } + if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { + shared_max[local_id] = other_max; + } + } + barrier(); + } + + // Final calculation for this token + if (local_id == 0) { + float token_min = shared_min[0]; + float token_max = shared_max[0]; + + float scale_val; + int zero_point_val; + calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val); + + // Convert token_id to 3D coordinates for output texture + // Assuming output tensors have the same layout as input but with different dimensions + uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); + uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); + uint out_y = out_remainder / uint(t_scale_limits.x); + uint out_x = out_remainder % uint(t_scale_limits.x); + ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); + + write_texel(t_scale, out_pos, vec4(scale_val, 0.0, 0.0, 0.0)); + write_texel(t_zero_point, out_pos, ivec4(zero_point_val, 0, 0, 0)); + } + + // Synchronize before processing next token + barrier(); + } +} + +#endif + +void main() { + choose_qparams_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml new file mode 100644 index 00000000000..f3961b87a0f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -0,0 +1,12 @@ +choose_qparams_texture: + parameter_names_with_default_values: + IN_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: float + shader_variants: + - NAME: choose_qparams_tensor_texture3d + MODE: per_tensor + - NAME: choose_qparams_per_token_asymmetric_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh new file mode 100644 index 00000000000..7194bebda35 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef DEQUANTIZE_GLSLH +#define DEQUANTIZE_GLSLH + +OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { + return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); +} + +#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl new file mode 100644 index 00000000000..2a1f62719a0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * DEQUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer value from buffer + * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access + * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering + * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping + * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Dequantization Process: + * Input: -103 (int8) + * Step 1: qvalue - zero_point = -103 - (-128) = 25 + * Step 2: result * scale = 25 * 0.1 = 2.5 + * Output: 2.5 (float) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: value = (qvalue - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for element at tensor index (w, z, y, x): + * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y + * - 3D tensor: token_id = z * sizes.y + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + OUT_T value = dequantize_val(qvalue, scale, zero_point); + + t_out[out_bufi] = value; +} + +#else + +void dequantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T value = dequantize_val(qvalue, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = value; +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml new file mode 100644 index 00000000000..fb0d2ee61bf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -0,0 +1,19 @@ +dequantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + - VALUE: double + shader_variants: + - NAME: dequantize_per_tensor_buffer + MODE: per_tensor + - NAME: dequantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl new file mode 100644 index 00000000000..801f4a2f6a2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -0,0 +1,196 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "dequantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * DEQUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts n-bit integer tensor values back to floating-point representations + * using pre-computed quantization parameters (scale and zero_point). The dequantization + * reconstructs the original floating-point values from their discrete integer representations + * with minimal precision loss. + * + * ALGORITHM: + * 1. Load quantized integer texel (4 values) from 3D texture + * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale + * 3. Store reconstructed floating-point texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) for input/output textures + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * - Input/output textures: Must use standard axis mapping for per-token mode + * + * DEQUANTIZATION FORMULA VISUALIZATION: + * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: + * + * Integer Domain: Floating Point Domain: + * quant_min ──────────────► min_val + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * quant_max ──────────────► max_val + * + * Texel Dequantization Process: + * Input Texel: [-103, -128, -123, -96] (int4) + * Per-component dequantization with scale=0.1, zero_point=-128: + * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 + * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 + * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 + * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 + * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) + * + * PER-TENSOR DEQUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same dequantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: value[i] = (qvalue[i] - zero_point) * scale + * + * PER-TOKEN DEQUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] + * + * Token ID calculation for texel at position (x, y, z): + * - 3D tensor: token_id = z * texture_height + y + * - 2D tensor: token_id = y + * - 1D tensor: token_id = 0 + */ + +#ifdef per_tensor + +void dequantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + // Skip if out of bounds + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale, zero_point); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + write_texel(t_out, pos, outtex); +} + +#else + +void dequantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + IVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + FVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + dequantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml new file mode 100644 index 00000000000..7d19a543a03 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -0,0 +1,19 @@ +dequantize_texture: + parameter_names_with_default_values: + IN_DTYPE: int32 + OUT_DTYPE: float + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + OUT_DTYPE: + - VALUE: half + - VALUE: float + - VALUE: double + shader_variants: + - NAME: dequantize_per_tensor_texture3d + MODE: per_tensor + - NAME: dequantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" and DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh new file mode 100644 index 00000000000..cde72e41ac7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize.glslh @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef QUANTIZE_GLSLH +#define QUANTIZE_GLSLH + +OUT_T quantize_val(IN_T value, float scale_val, int zero_point_val) { + float inv_scale = 1.0 / scale_val; + + float rounded_float = round(inv_scale * float(value)); + + int qvalue = zero_point_val + int(rounded_float); + + qvalue = max(qvalue, quant_min); + qvalue = min(qvalue, quant_max); + + return OUT_T(qvalue); +} + +#endif // QUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl new file mode 100644 index 00000000000..ea0c2f7dce7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +#define ${MODE} + +${define_active_storage_type("buffer")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "int", "out_numel")} +${layout_declare_ubo(B, "ivec4", "t_in_sizes")} +${layout_declare_ubo(B, "ivec4", "t_in_strides")} +${layout_declare_ubo(B, "ivec4", "t_out_sizes")} +${layout_declare_ubo(B, "ivec4", "t_out_strides")} + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} + +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); +const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); + +/* + * QUANTIZATION SHADER (BUFFER STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point input value from buffer + * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point + * 3. Clamp result to [quant_min, quant_max] range + * 4. Store quantized integer value to output buffer + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) + * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access + * - and supports any tensor layout through stride calculations and dimension ordering + * - Per-Token Config: Assumes width-packed layout (packed_dim = 0) + * - since that is how token index is calculated + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Quantization Process: + * Input: 2.5 (float) + * Step 1: value / scale = 2.5 / 0.1 = 25.0 + * Step 2: round(25.0) + zero_point = 25 + (-128) = -103 + * Step 3: clamp(-103, -128, 127) = -103 + * Output: -103 (int8) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All elements use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates its token_id from tensor coordinates + * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + OUT_T qvalue = quantize_val(value, scale, zero_point); + + t_out[out_bufi] = qvalue; +} + +#else + +void quantize_per_token() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T value = t_in[in_bufi]; + + int token_idx = 0; + + if (t_out_sizes.w > 1) { + // 4D tensor + token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.z > 1) { + // 3D tensor + token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; + } else if (t_out_sizes.y > 1) { + // 2D tensor + token_idx = out_tidx.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + OUT_T qvalue = quantize_val(value, t_scale[token_idx], t_zero_point[token_idx]); + + t_out[out_bufi] = qvalue; +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml new file mode 100644 index 00000000000..4d95d610314 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -0,0 +1,19 @@ +quantize_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + - VALUE: double + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_buffer + MODE: per_tensor + - NAME: quantize_per_token_buffer + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl new file mode 100644 index 00000000000..9ba7074f75b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} + +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} +#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} + +#define ${MODE} + +${define_active_storage_type("texture3d")} +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} + +$if MODE == "per_tensor": + layout(push_constant) uniform restrict Block { + float scale; + int zero_point; + int quant_min; + int quant_max; + }; +$if MODE == "per_token": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + int num_tokens; + int quant_min; + int quant_max; + }; + +${layout_declare_ubo(B, "ivec3", "t_in_limits")} +${layout_declare_ubo(B, "ivec3", "t_out_limits")} + +#include "indexing_utils.h" +#include "quantize.glslh" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * QUANTIZATION SHADER (TEXTURE STORAGE) + * + * This shader converts floating-point tensor values to n-bit integer representations + * using pre-computed quantization parameters (scale and zero_point). The quantization + * maps floating-point values to a discrete integer range while preserving the + * original data distribution as much as possible. + * + * ALGORITHM: + * 1. Load floating-point texel (4 values) from 3D texture + * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point + * 3. Clamp each result to [quant_min, quant_max] range + * 4. Store quantized integer texel to output texture + * + * WORKGROUP CONFIGURATION: + * - Per-Tensor Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * - Per-Token Mode: + * - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing + * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) + * + * SUPPORTED CONFIGURATIONS: + * - Texture Storage: Uses 3D texture indexing with texel-based processing + * - Assumes width-packed layout (packed_dim = 0) in current implementation + * - Handles texel padding for non-multiple-of-4 tensor dimensions + * - For per-token mode: scale/zero_point tensors must use buffer storage + * + * QUANTIZATION FORMULA VISUALIZATION: + * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]: + * + * Floating Point Domain: Integer Domain: + * min_val ────────────────► quant_min + * │ │ + * │ scale = (max_val - min_val) / (quant_max - quant_min) + * │ zero_point = quant_min - round(min_val / scale) + * │ │ + * max_val ────────────────► quant_max + * + * Texel Quantization Process: + * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4) + * Per-component quantization with scale=0.1, zero_point=-128: + * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103 + * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128 + * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123 + * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96 + * Output Texel: [-103, -128, -123, -96] (int4) + * + * PER-TENSOR QUANTIZATION: + * - Single scale and zero_point values for entire tensor + * - All texel components use same quantization parameters + * - Parameters passed as push constants for efficiency + * - Each thread processes one texel (4 elements) independently + * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max) + * + * PER-TOKEN QUANTIZATION: + * - Separate scale and zero_point for each token + * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) + * - Parameters stored in buffer arrays indexed by token_id + * - Each thread calculates token_id from its 3D texture position + * - Scale/zero_point buffers accessed directly (not as textures) + * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max) + */ + +#ifdef per_tensor + +void quantize_per_tensor() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + IVEC4_T outtex; + + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale, zero_point); + outtex[i] = qvalue; + } + write_texel(t_out, pos, outtex); +} + +#else + +void quantize_per_token() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) { + return; + } + + FVEC4_T intex = load_texel(t_in, pos); + + int token_idx = 0; + ivec3 dims = t_in_limits; + + if (dims.z > 1) { + // 3D tensor + token_idx = pos.z * dims.y + pos.y; + } else if (dims.y > 1) { + // 2D tensor + token_idx = pos.y; + } + // For 1D tensor, token_idx remains 0 + + token_idx = min(token_idx, num_tokens - 1); + + // Scale and zero_point are prepacked as buffers, so direct access + float scale_val = t_scale[token_idx]; + int zero_point_val = t_zero_point[token_idx]; + + IVEC4_T outtex; + [[unroll]] for (int i = 0; i < 4; ++i) { + IN_T value = IN_T(intex[i]); + OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); + outtex[i] = qvalue; + } + + write_texel(t_out, pos, outtex); +} + +#endif + +void main() { + quantize_${MODE}(); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml new file mode 100644 index 00000000000..65002ce26b6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -0,0 +1,19 @@ +quantize_texture: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: int32 + MODE: per_tensor + generate_variant_forall: + IN_DTYPE: + - VALUE: half + - VALUE: float + - VALUE: double + OUT_DTYPE: + - VALUE: uint8 + - VALUE: int8 + - VALUE: int32 + shader_variants: + - NAME: quantize_per_tensor_texture3d + MODE: per_tensor + - NAME: quantize_per_token_texture3d + MODE: per_token diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp new file mode 100644 index 00000000000..1dc2d34afbf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -0,0 +1,347 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +namespace vkcompute { + +namespace { + +void resize_choose_qparams_tensor_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + + // Both scale and zero_point are scalar tensors for per-tensor quantization + // Since we use single workgroup approach, no extra buffer space needed + graph->virtual_resize(scale_out, {}); + graph->virtual_resize(zero_point_out, {}); +} + +void resize_choose_qparams_per_token_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef scale_out = args.at(0).refs.at(0); + const ValueRef zero_point_out = args.at(0).refs.at(1); + const ValueRef input = args.at(1).refs.at(0); + + // Calculate output sizes for scale and zero_point tensors + const auto input_sizes = graph->sizes_of(input); + std::vector output_sizes; + output_sizes.reserve(input_sizes.size() - 1); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + output_sizes.push_back(input_sizes[i]); + } + output_sizes.push_back(1); + + graph->virtual_resize(scale_out, output_sizes); + graph->virtual_resize(zero_point_out, output_sizes); +} + +// Custom workgroup size pickers for ChooseQParams operations +utils::uvec3 choose_qparams_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + // For per-tensor quantization, we want a single workgroup that can handle + // all elements with proper reduction. The shader uses NWORKERS=64 threads. + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use a single workgroup in X dimension + // The shader will handle strided access across all elements + return {1u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + // This ensures the shared memory arrays are properly sized + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +utils::uvec3 choose_qparams_per_token_pick_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For per-token quantization, we need one workgroup per token + // Calculate number of tokens (product of all dimensions except the last + // one) + const auto input_sizes = graph->sizes_of(input); + int64_t num_tokens = 1; + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + return {static_cast(num_tokens), 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_global_wg_size(args.at(0).refs.at(0)); + } +} + +utils::uvec3 choose_qparams_per_token_pick_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef input = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(input)) { + // For buffer storage, use 64 threads in X dimension to match NWORKERS + return {64u, 1u, 1u}; + } else { + // For texture storage, use the default logic + return graph->create_local_wg_size(global_workgroup_size); + } +} + +} // namespace + +void add_choose_qparams_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_pick_global_wg_size, + choose_qparams_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_tensor_output)); +} + +void add_choose_qparams_per_token_asymmetric_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale_out, + const ValueRef& zero_point_out) { + std::string kernel_name("choose_qparams_per_token_asymmetric"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + int num_tokens_val = static_cast(num_tokens); + int quant_min_val = -128; // Fixed for asymmetric quantization + int quant_max_val = 127; // Fixed for asymmetric quantization + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zero_point_out), + graph.strides_ubo(zero_point_out)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(scale_out), + graph.logical_limits_ubo(zero_point_out)}; + } + + std::vector push_constants; + push_constants = { + PushConstantDataInfo(&num_tokens_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_per_token_pick_global_wg_size, + choose_qparams_per_token_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zero_point_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_choose_qparams_per_token_output)); +} + +void choose_qparams_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } + + add_choose_qparams_tensor_node( + graph, input, quant_min, quant_max, scale_out, zero_point_out); +} + +void choose_qparams_per_token_asymmetric_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale_out = args[arg_idx++]; + const ValueRef zero_point_out = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale_out)); + VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf || + graph.dtype_of(input) == vkapi::kDouble); + + // Verify output types - accept CPU types but convert to GPU types + VK_CHECK_COND( + graph.dtype_of(scale_out) == vkapi::kFloat || + graph.dtype_of(scale_out) == vkapi::kDouble); + VK_CHECK_COND( + graph.dtype_of(zero_point_out) == vkapi::kInt || + graph.dtype_of(zero_point_out) == vkapi::kLong); + + add_choose_qparams_per_token_asymmetric_node( + graph, input, scale_out, zero_point_out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl); + VK_REGISTER_OP( + choose_qparams_per_token_asymmetric.default, + choose_qparams_per_token_asymmetric_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp new file mode 100644 index 00000000000..77a51ce24f9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_dequantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_dequantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void add_dequantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_output)); +} + +void dequantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + add_dequantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void dequantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is an integer type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kByte || + graph.dtype_of(input) == vkapi::kChar || + graph.dtype_of(input) == vkapi::kShort || + graph.dtype_of(input) == vkapi::kInt); + + // Verify output is a floating point type + VK_CHECK_COND( + graph.dtype_of(output) == vkapi::kHalf || + graph.dtype_of(output) == vkapi::kFloat || + graph.dtype_of(output) == vkapi::kDouble); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_dequantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl); + VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp new file mode 100644 index 00000000000..49277b4d718 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +namespace vkcompute { + +namespace { + +void resize_quantize_output( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + graph->virtual_resize(out, graph->sizes_of(in)); +} + +} // namespace + +void add_quantize_per_tensor_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_tensor"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + float scale_val = static_cast(graph.get_double(scale)); + int zero_point_val = static_cast(graph.get_int(zero_point)); + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + push_constants = { + PushConstantDataInfo(&scale_val, sizeof(float)), + PushConstantDataInfo(&zero_point_val, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void add_quantize_per_token_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("quantize_per_token"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + int num_tokens = static_cast(graph.sizes_of(scale)[0]); + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } else { + param_ubos = { + graph.logical_limits_ubo(input), + graph.logical_limits_ubo(output), + }; + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + } + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_quantize_output)); +} + +void quantize_per_tensor_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + add_quantize_per_tensor_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +void quantize_per_token_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef input = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + const ValueRef zero_point = args[arg_idx++]; + const ValueRef quant_min = args[arg_idx++]; + const ValueRef quant_max = args[arg_idx++]; + const ValueRef output = args[arg_idx++]; + + // Check tensor types + VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); + VK_CHECK_COND(graph.val_is_tensor(output)); + + // Verify input is a floating point type + VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || + graph.dtype_of(input) == vkapi::kFloat || + graph.dtype_of(input) == vkapi::kHalf); + + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } + + // Calculate number of tokens (product of all dimensions except the last one) + int64_t num_tokens = 1; + const auto input_sizes = graph.sizes_of(input); + for (size_t i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + const auto scale_sizes = graph.sizes_of(scale); + const auto zero_point_sizes = graph.sizes_of(zero_point); + + VK_CHECK_COND(scale_sizes.size() == 1); + VK_CHECK_COND(zero_point_sizes.size() == 1); + VK_CHECK_COND(scale_sizes[0] == num_tokens); + VK_CHECK_COND(zero_point_sizes[0] == num_tokens); + + add_quantize_per_token_node( + graph, input, scale, zero_point, quant_min, quant_max, output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl); + VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: + return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp new file mode 100644 index 00000000000..55e96151387 --- /dev/null +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -0,0 +1,771 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +std::tuple choose_qparams_tensor_out( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +std::tuple choose_qparams_per_token_asymmetric_out( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +// Wrapper function for choose_qparams_tensor_out without context +Tensor& choose_qparams_tensor_out_no_context( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_tensor_out( + input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); + return scale_out; +} + +// Wrapper function for choose_qparams_per_token_asymmetric_out without context +Tensor& choose_qparams_per_token_asymmetric_out_no_context( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_per_token_asymmetric_out( + input, dtype, scale_out, zero_point_out); + return scale_out; +} + +// ATen wrapper for choose_qparams_tensor +std::tuple choose_qparams_tensor_aten( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + double eps = 1e-7; + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) + (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +// ATen wrapper for choose_qparams_per_token_asymmetric +std::tuple choose_qparams_per_token_asymmetric_aten( + const at::Tensor& input, + at::ScalarType dtype) { + // Calculate output sizes for scale and zero_point tensors + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + auto scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) + (input, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +} // namespace native +} // namespace executor +} // namespace torch + +// +// Reference Implementation +// + +/* + * Reference implementation of choose_qparams_tensor + */ +std::tuple choose_qparams_tensor_reference_impl( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max) { + // Create output tensors + at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + + // Find min and max values in the input tensor + float min_val = input.min().item(); + float max_val = input.max().item(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = quant_min - min_val / static_cast(scale); + double zero_point_from_max = quant_max - max_val / static_cast(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::nearbyint(static_cast(initial_zero_point)); + } + + // Set output values - use item_mutable() for scalar tensors + scale_out.fill_(scale); + zero_point_out.fill_(nudged_zero_point); + + return std::make_tuple(scale_out, zero_point_out); +} + +/* + * Reference implementation of choose_qparams_per_token_asymmetric + */ +std::tuple +choose_qparams_per_token_asymmetric_reference_impl( + const at::Tensor& input, + at::ScalarType dtype) { + // For per-token quantization, we need to compute scale and zero_point for + // each token + int64_t quant_min = -128; + int64_t quant_max = 127; + + // Calculate output sizes + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Create output tensors + at::Tensor scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + // Calculate number of tokens + int64_t num_tokens = 1; + for (int64_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + + // Process each token + for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { + at::Tensor token = reshaped_input[token_idx]; + + // Find min and max values for this token + float min_val = token.min().item(); + float max_val = token.max().item(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = + quant_min - min_val / static_cast(scale); + double zero_point_from_max = + quant_max - max_val / static_cast(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = + std::nearbyint(static_cast(initial_zero_point)); + } + + // Set output values for this token - use index_put_ for safety + scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); + zero_point_out.view({num_tokens, 1}) + .index_put_({token_idx, 0}, nudged_zero_point); + } + + return std::make_tuple(scale_out, zero_point_out); +} + +// Forward declaration of implementation functions +void test_vulkan_choose_qparams_tensor_impl( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_tensor( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_per_token_asymmetric( + const std::vector& input_sizes, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_choose_qparams_tensor( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_tensor_reference_impl(input, quant_min, quant_max); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_tensor_impl( + const std::vector& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Build Vulkan choose_qparams_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + // Output tensors + const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams.tensor") + (graph, + { + r_input.value, + r_quant_min, + r_quant_max, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); + at::Tensor vk_zero_point = + at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + // make sure that there arent a ton of elements in the input tensor + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { + test_reference_choose_qparams_tensor( + {2, 3, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 3, 2, 4}, // input sizes + 0, // quant_min + 255, // quant_max + at::kByte); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {5, 5}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {12, 8, 2}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_tensor( + {10, 10, 6, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +void test_reference_choose_qparams_per_token_asymmetric( + const std::vector& input_sizes, + at::ScalarType dtype) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_per_token_asymmetric_reference_impl(input, dtype); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Calculate output sizes + std::vector output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Build Vulkan choose_qparams_per_token_asymmetric graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Output tensors + const ValueRef r_scale = + graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = + graph.add_tensor(output_sizes, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_per_token_asymmetric + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) + .contiguous(); + at::Tensor vk_zero_point = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) + .contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST( + VulkanChooseQparamsTest, + test_reference_choose_qparams_per_token_asymmetric_int8) { + test_reference_choose_qparams_per_token_asymmetric( + {2, 3, 4}, // input sizes (2*3=6 tokens) + at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); +} + +TEST( + VulkanChooseQparamsTest, + test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); +} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..6c604076c41 --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,1341 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { + using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kHalf: + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast(out_dtype), + ")"); + } +} + +// +// Reference Implementation +// + +/* + * Reference implementation of dequantize_per_tensor + */ +at::Tensor dequantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Dequantize the input tensor + at::Tensor flat_input = input.flatten(); + at::Tensor flat_out = out.flatten(); + + // Store casted values to avoid repeated casting + const int32_t zero_point_int32 = static_cast(zero_point); + const float scale_float = static_cast(scale); + + for (int i = 0; i < flat_input.numel(); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = flat_input[i].item(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + flat_out[i] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + flat_out[i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + flat_out[i] = static_cast(dequantized_value); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of dequantize_per_token + */ +at::Tensor dequantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Dequantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Get scale and zero_point for this token + float token_scale = scale[token_idx].item(); + int64_t token_zero_point = zero_point[token_idx].item(); + + // Store casted values to avoid repeated casting + const int32_t token_zero_point_int32 = + static_cast(token_zero_point); + + // Dequantize the token + for (int i = 0; i < input.size(-1); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kChar) { + int8_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kShort) { + int16_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kInt) { + int32_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kLong) { + int64_t qvalue = reshaped_input[token_idx][i].item(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + reshaped_out[token_idx][i] = static_cast(dequantized_value); + } else if (out_dtype == at::kDouble) { + reshaped_out[token_idx][i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + reshaped_out[token_idx][i] = static_cast(dequantized_value); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_dequantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_dequantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_dequantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = dequantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = + torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Build Vulkan dequantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + + const ValueRef r_scale = graph.add_scalar(scale); + const ValueRef r_zero_point = graph.add_scalar(zero_point); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan dequantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_float) { + test_reference_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int8_to_float) { + test_reference_dequantize_per_tensor( + {3, 4, 5}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_float) { + test_reference_dequantize_per_tensor( + {4, 6, 2}, // input sizes + 0.2, // scale + 2, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_half) { + test_reference_dequantize_per_tensor( + {7, 4}, // input sizes + 0.1, // scale + 10, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype (uint8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_half) { + test_reference_dequantize_per_tensor( + {2, 6, 5}, // input sizes + 0.3, // scale + -10, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {3, 4}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_float) { + test_vulkan_dequantize_per_tensor( + {2, 4, 3, 12}, // input sizes + 0.0001, // scale + 100, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scale to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + test_vulkan_dequantize_per_tensor( + {7}, // input sizes + 1e-5, // scale (much smaller to avoid overflow) + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} + +void test_reference_dequantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = dequantize_per_token_reference_impl( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Build Vulkan dequantize_per_token graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs with appropriate tolerance for half precision + bool output_correct; + if (out_dtype == at::kHalf) { + // Use higher tolerance for half precision due to limited precision + output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); + } else { + output_correct = at::allclose(reference_out, vk_out); + } + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + std::cout << " input dtype: " << dtype << std::endl; + std::cout << " output dtype: " << out_dtype << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_uint8_to_float) { + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_reference_dequantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_float) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 5}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_float) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 10}, // input sizes (2*2=4 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_half) { + std::vector scales = {0.05, 0.1, 0.15, 0.2}; + std::vector zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {4, 1, 5}, // input sizes (4*1=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype (int8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_half) { + std::vector scales = {0.05, 0.1}; + std::vector zero_points = {0, -5}; + + test_reference_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_uint8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector zero_points = {5, 10, 15, 20, 25, 30}; + + test_vulkan_dequantize_per_token( + {2, 3, 6}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_float) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.0}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_float) { + std::vector scales = { + 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; + std::vector zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; + + test_vulkan_dequantize_per_token( + {2, 2, 2, 12}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.2}; + std::vector zero_points = {2, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int32_to_half) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + // Use much smaller scales to avoid overflow to infinity in half precision + // Half precision max value is ~65504, so with int32 values around 2e9, + // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow + std::vector scales = {1e-5, 2e-5, 1.5e-5}; + std::vector zero_points = {20, -15, 1}; + + test_vulkan_dequantize_per_token( + {3, 6}, // input sizes (3 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.001}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index b95b7b3aa6d..e48042c4620 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_linear_qga4w_impl( const int B, const int M, diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp new file mode 100644 index 00000000000..150bda6989e --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -0,0 +1,1128 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include +#include + +#include "test_utils.h" + +#include +#include +#include + +float eps = 1e-7; + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& quantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +// Wrapper function for quantize_per_tensor_out without context +Tensor& quantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// Wrapper function for quantize_per_token_out without context +Tensor& quantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// ATen wrapper for quantize_per_tensor +at::Tensor quantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +// ATen wrapper for quantize_per_token +at::Tensor quantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_quantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType out_dtype) { + using namespace vkcompute; + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + switch (out_dtype) { + case c10::kByte: + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + break; + case c10::kChar: + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + break; + case c10::kBits16: + case c10::kUInt16: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + case c10::kShort: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + case c10::kInt: + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + break; + default: + VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); + } + VK_CHECK_COND( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: ", + quant_min_lower_bound, + " actual quant_min: ", + quant_min); + + VK_CHECK_COND( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: ", + quant_max_upper_bound, + " actual quant_max: ", + quant_max); +} + +// +// Reference Implementation +// + +/* + * Reference implementation of quantize_per_tensor + */ +at::Tensor quantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Quantize the input tensor + float inv_scale = 1.0 / scale; + + // Iterate through the tensor and quantize each element + at::Tensor float_input = input.to(at::kFloat); + at::Tensor float_values = float_input.flatten(); + + auto out_flat = out.flatten(); + + for (int i = 0; i < float_values.numel(); i++) { + float value = float_values[i].item(); + int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + + if (dtype == at::kByte) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + out_flat[i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + out_flat[i] = static_cast(qvalue); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of quantize_per_token + */ +at::Tensor quantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Quantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Use float for scale since Vulkan doesn't support double + float token_scale = scale[token_idx].item(); + // Use int for zero_point since Vulkan doesn't support int64_t + int token_zero_point = zero_point[token_idx].item(); + + float inv_scale = 1.0 / token_scale; + + // Quantize the token + for (int i = 0; i < input.size(-1); i++) { + float value = reshaped_input[token_idx][i].item(); + int qvalue = token_zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + + if (dtype == at::kByte) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kChar) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kShort) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kInt) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } else if (dtype == at::kLong) { + reshaped_out[token_idx][i] = static_cast(qvalue); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_quantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_quantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_token( + const std::vector& input_sizes, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + + // Test with texture storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_quantize_per_tensor( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + scale = scale < eps ? eps : scale; + + // Get reference output + at::Tensor reference_out = quantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - impl_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "my_reference:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_tensor_impl( + const std::vector& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + scale = scale < eps ? eps : scale; + + // Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Build Vulkan quantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_scale = graph.add_scalar(scale); + const ValueRef r_zero_point = graph.add_scalar(zero_point); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan quantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + // For quantized types, we need to compare the actual integer values + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::allclose(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.04, // scale + 5, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_uint8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.2, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kHalf, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32) { + test_vulkan_quantize_per_tensor( + {5, 3, 2, 4}, // input sizes + 0.01, // scale + 1, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_float_to_int32_small_scale) { + test_vulkan_quantize_per_tensor( + {2, 8, 1, 3}, // input sizes + 0.0, // scale + 20, // zero_point + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +} + +void test_reference_quantize_per_token( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0 / (input.numel() - 1); + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scales and zero_points + ASSERT_EQ(num_tokens, pre_scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = quantize_per_token_reference_impl( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "my_reference:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_token_impl( + const std::vector& input_sizes, + const std::vector& pre_scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, pre_scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + std::vector scales = pre_scales; + for (auto& s : scales) { + s = s < eps ? eps : s; + } + + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output to show what we would compare against + at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::allclose(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int8) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int32) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_int32) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits::min(), // quant_min + std::numeric_limits::max(), // quant_max + at::kHalf, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_uint8) { + std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kFloat, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32) { + std::vector scales = { + -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; + std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; + + test_vulkan_quantize_per_token( + {5, 2, 4}, // input sizes (5*2=10 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_int32_small_scales) { + std::vector scales = { + 0, + 2.9387358770557188e-39f, + 1.40129846e-45f, + 1.17549435e-38f, + 0.0000000000001}; + std::vector zero_points = {20, -10, 15, 200, 50}; + + test_vulkan_quantize_per_token( + {5, 2}, // input sizes (3 tokens) + scales, + zero_points, + -2147483648, // quant_min + 2147483647, // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales(18, 0.1); + std::vector zero_points(18, 5); + + // Alternate scale values + for (size_t i = 0; i < scales.size(); i++) { + scales[i] = (i % 2 == 0) ? 0.3 : -0.5; + } + + test_vulkan_quantize_per_token( + {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) + scales, + zero_points, + 0, // quant_min + 125, // quant_max + at::kFloat, + at::kByte); +} + +TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kHalf, // input dtype + at::kChar); // output dtype +} + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +} diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index 534bb577e7a..eebbb89ab40 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -14,6 +14,8 @@ #include #include +#include "test_utils.h" + #include // @@ -55,26 +57,6 @@ std::pair rotary_embedding_impl( // Test functions // -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_reference( const int n_heads = 4, const int n_kv_heads = 2, diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index 772039eda6a..79b679674a5 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -18,6 +18,8 @@ #include #include +#include "test_utils.h" + #include #include @@ -261,24 +263,6 @@ void test_reference_sdpa( } } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_sdpa( const int start_input_pos, const int base_sequence_len, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 5c9afa40762..0d014c7ef29 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), ) + runtime.cxx_library( + name = "test_utils", + srcs = [ + "test_utils.cpp", + ], + headers = [ + "test_utils.h", + ], + exported_headers = [ + "test_utils.h", + ], + deps = [ + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/runtime/core/exec_aten:lib", + runtime.external_dep_location("libtorch"), + ], + visibility = [ + "//executorch/backends/vulkan/test/op_tests/...", + "@EXECUTORCH_CLIENTS", + ], + ) + define_test_targets( "compute_graph_op_tests", src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" @@ -150,9 +172,47 @@ def define_common_targets(is_fbcode = False): define_test_targets( "sdpa_test", extra_deps = [ + ":test_utils", "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", ] ) - define_test_targets("linear_weight_int4_test") - define_test_targets("rotary_embedding_test") + define_test_targets( + "quantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_quantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "dequantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "choose_qparams_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_choose_qparams", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "linear_weight_int4_test", + extra_deps = [ + ":test_utils", + ] + ) + define_test_targets( + "rotary_embedding_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp new file mode 100644 index 00000000000..c5702abd079 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "test_utils.h" + +#include + +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype) { + using ScalarType = executorch::aten::ScalarType; + switch (dtype) { + case at::kByte: + return ScalarType::Byte; + case at::kChar: + return ScalarType::Char; + case at::kShort: + return ScalarType::Short; + case at::kInt: + return ScalarType::Int; + case at::kLong: + return ScalarType::Long; + case at::kHalf: + return ScalarType::Half; + case at::kFloat: + return ScalarType::Float; + case at::kDouble: + return ScalarType::Double; + default: + throw std::runtime_error("Unsupported dtype"); + } +} + +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + case c10::kFloat: + return "c10::kFloat"; + case c10::kHalf: + return "c10::kHalf"; + case c10::kInt: + return "c10::kInt"; + case c10::kChar: + return "c10::kChar"; + case c10::kByte: + return "c10::kByte"; + case c10::kDouble: + return "c10::kDouble"; + case c10::kUInt16: + return "c10::kUInt16"; + case c10::kBits16: + return "c10::kBits16"; + default: + return "Unknown(" + std::to_string(static_cast(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kHalf: + return vkapi::kHalf; + case c10::kFloat: + return vkapi::kFloat; + case c10::kDouble: + return vkapi::kDouble; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + // No support for 64-bit integers + return vkapi::kInt; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast(at_scalartype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/test_utils.h b/backends/vulkan/test/op_tests/test_utils.h new file mode 100644 index 00000000000..369767007e0 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +/** + * Convert at::ScalarType to executorch::ScalarType + */ +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype); + +/** + * Get the string name of a c10::ScalarType for better error messages + */ +std::string scalar_type_name(c10::ScalarType dtype); + +/** + * Convert c10::ScalarType to vkcompute::vkapi::ScalarType + */ +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype); diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: diff --git a/backends/xnnpack/README.md b/backends/xnnpack/README.md index 411bec99d79..6e6be7ddb4c 100644 --- a/backends/xnnpack/README.md +++ b/backends/xnnpack/README.md @@ -105,6 +105,7 @@ mkdir cmake-out cmake \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py index 0d4e69e9f8e..1d824d234ee 100644 --- a/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py +++ b/backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py @@ -91,18 +91,10 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool: return not self.is_nhwc_node(node) def requires_nhwc_input(self, node: torch.fx.Node) -> bool: - return ( - node.target in self.memory_sensitive_ops_nhwc - or node.name == "output" - and not node.args[0][0].meta["val"].is_contiguous() - ) + return node.target in self.memory_sensitive_ops_nhwc def requires_nchw_inputs(self, node: torch.fx.Node) -> bool: - return ( - node.target in self.memory_sensitive_ops_nchw - or node.name == "output" - and node.args[0][0].meta["val"].is_contiguous() - ) + return node.target in self.memory_sensitive_ops_nchw def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool: # There are two conditions that must be met for a node to be able to @@ -380,18 +372,21 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901 # This node has no inputs so we don't need to change anything continue - if self.requires_nhwc_input(node): + # Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes + if node.op == "output": + out_tuple = node.args[0] + for out_node in out_tuple: + if out_node.meta["val"].is_contiguous(): + self.input_to_nchw(graph_module, out_node, node) + else: + self.input_to_nhwc(graph_module, out_node, node) + elif self.requires_nhwc_input(node): # Nodes which enter this branch are ones that require their # first input to be nhwc. This makes this node's output nhwc too - # Currently, all nodes like this should have all of their other - # inputs as nchw, so fail if this is not true - if node.name == "output": - self.input_to_nhwc(graph_module, node.args[0][0], node) - else: - self.input_to_nhwc(graph_module, node.args[0], node) - - for input_node in node.all_input_nodes[1:]: - if self.is_nhwc_node(input_node): + + self.input_to_nhwc(graph_module, node.args[0], node) + for input_node in node.all_input_nodes: + if input_node.op == "placeholder" and self.is_nhwc_node(input_node): raise AssertionError( f"Expected {input_node} to be NCHW in channels last reshape pass" ) diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py index ec07502de54..35ce639978d 100644 --- a/backends/xnnpack/operators/__init__.py +++ b/backends/xnnpack/operators/__init__.py @@ -50,5 +50,6 @@ op_static_constant_pad, op_static_resize_bilinear_2d, op_sub, + op_tanh, op_to_copy, ) diff --git a/backends/xnnpack/operators/op_tanh.py b/backends/xnnpack/operators/op_tanh.py new file mode 100644 index 00000000000..6031839eceb --- /dev/null +++ b/backends/xnnpack/operators/op_tanh.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +from executorch.backends.xnnpack.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import ( + XNNGraph, + XNNTanh, + XNode, +) +from executorch.backends.xnnpack.utils.utils import get_input_node + + +@register_node_visitor +class TanhVisitor(NodeVisitor): + target = "aten.tanh.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + xnn_graph: XNNGraph, + vals_to_ids: Dict[torch.fx.Node, int], + debug_handle: int, + ) -> None: + self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids) + + # input + input_id = vals_to_ids[get_input_node(node, 0)] + + # output + output_id = vals_to_ids[node] + + ser_node = XNode( + xnode_union=XNNTanh( + input_id=input_id, + output_id=output_id, + flags=0, + ), + debug_handle=debug_handle, + ) + xnn_graph.xnodes.append(ser_node) diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 553b10f60d1..207d2cfd713 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -49,6 +49,7 @@ SoftmaxConfig, SquareRootConfig, SubConfig, + TanhConfig, UpsampleBilinear2dConfig, ) from executorch.backends.xnnpack.partition.config.node_configs import ( @@ -99,6 +100,7 @@ PreluConfig, ReciprocalSquareRootConfig, ReLUConfig, + TanhConfig, # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails SigmoidConfig, SliceCopyConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 46922e47010..e7e298053c6 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] +class TanhConfig(GenericNodePartitionerConfig): + target_name = "tanh.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32] + + class MeanDimConfig(GenericNodePartitionerConfig): target_name = "mean.dim" diff --git a/backends/xnnpack/partition/configs.py b/backends/xnnpack/partition/configs.py index 59a70d64a76..246f571c9c8 100644 --- a/backends/xnnpack/partition/configs.py +++ b/backends/xnnpack/partition/configs.py @@ -66,6 +66,7 @@ exir_ops.edge.aten.rsqrt.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.tanh.default, ] SUPPORTED_MODULES = [ diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 130eda03f88..3c82a65ad71 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -251,6 +251,15 @@ class QuantPattern: torch.ops.aten.convolution.default, } +CONV_TRANSPOSE_TARGETS = { + torch.ops.aten.conv_transpose1d, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d, + torch.ops.aten.conv_transpose3d.input, +} + LINEAR_TARGETS = { torch.ops.aten.linear.default, } @@ -269,14 +278,14 @@ class XNNPACKQuantizer(Quantizer): SUPPORTED_PATTERNS = [ QuantPattern("conv_bn_relu", False, True, CONV_TARGETS), QuantPattern("conv_bn", False, True, CONV_TARGETS), - QuantPattern("conv_transpose_bn_relu", False, True, CONV_TARGETS), - QuantPattern("conv_transpose_bn", False, True, CONV_TARGETS), + QuantPattern("conv_transpose_bn_relu", False, True, CONV_TRANSPOSE_TARGETS), + QuantPattern("conv_transpose_bn", False, True, CONV_TRANSPOSE_TARGETS), QuantPattern("linear_relu", False, False, LINEAR_TARGETS), QuantPattern("linear", True, False, LINEAR_TARGETS), QuantPattern("conv", True, False, CONV_TARGETS), - QuantPattern("conv_transpose", False, False, CONV_TARGETS), + QuantPattern("conv_transpose", True, False, CONV_TRANSPOSE_TARGETS), QuantPattern("conv_relu", False, False, CONV_TARGETS), - QuantPattern("conv_transpose_relu", False, False, CONV_TARGETS), + QuantPattern("conv_transpose_relu", False, False, CONV_TRANSPOSE_TARGETS), QuantPattern("adaptive_avg_pool2d", False, False, ADAPTIVE_AVG_POOL2D_TARGETS), QuantPattern("add_relu", False, False, ADD_TARGETS), QuantPattern("add", False, False, ADD_TARGETS), diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 0dcfb4484ed..3d687d0b513 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -4,7 +4,10 @@ import torch import torch.nn.functional as F -from executorch.backends.xnnpack.utils.utils import is_depthwise_conv +from executorch.backends.xnnpack.utils.utils import ( + get_groups_from_conv, + is_depthwise_conv, +) from torch._subclasses import FakeTensor from torch.fx import Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( @@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None: return decorator +def change_quantization_config( + original_qspec, + dtype=None, + quant_min=None, + quant_max=None, + qscheme=None, + ch_axis=None, + is_dynamic=None, + observer_or_fake_quant_ctr=None, +): + return QuantizationSpec( + dtype=dtype or original_qspec.dtype, + quant_min=quant_min or original_qspec.quant_min, + quant_max=quant_max or original_qspec.quant_max, + qscheme=qscheme or original_qspec.qscheme, + ch_axis=ch_axis or original_qspec.ch_axis, + is_dynamic=is_dynamic or original_qspec.is_dynamic, + observer_or_fake_quant_ctr=observer_or_fake_quant_ctr + or original_qspec.observer_or_fake_quant_ctr, + ) + + def is_relu_node(node: Node) -> bool: """ Check if a given node is a relu node @@ -231,6 +256,9 @@ def _do_annotate_conv( if is_relu_node(user): continue + # Tracks conditions for whether or not to skip + skip = False + input_qspec_map = {} input_act = conv_node.args[0] assert isinstance(input_act, Node) @@ -238,24 +266,34 @@ def _do_annotate_conv( weight = conv_node.args[1] assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + num_groups = get_groups_from_conv(conv_node) - # Only annotate dynamically quantized conv if it's 2D and not depthwise - if ( + # skip if transposed conv has more than 1 group + skip = skip or (is_conv_transpose and num_groups != 1) + print(f"{skip} conv transpose and num_groups") + + if is_conv_transpose: + # transposed convs per output channel quantization + weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) + + input_qspec_map[weight] = weight_qspec + is_dynamic = ( quantization_config and quantization_config.input_activation and quantization_config.input_activation.is_dynamic - ): + ) + + # Only annotate dynamically quantized conv if it's 2D and not depthwise + if is_dynamic: weight_val = weight.meta.get("val", None) weight_shape = getattr(weight_val, "shape", None) - # Skip if not a 4D weight tensor (i.e. not conv2d) - if weight_shape is not None and len(weight_shape) != 4: - continue - + skip = skip or (weight_shape is not None and len(weight_shape) != 4) # Skip if depthwise (default to groups=1 since it's not an arg) - if is_depthwise_conv(weight_shape, 1, is_conv_transpose): - continue + skip = skip or ( + not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False) + ) # adding weight node to the partition as well partition = [conv_node, conv_node.args[1]] @@ -265,7 +303,7 @@ def _do_annotate_conv( input_qspec_map[bias] = get_bias_qspec(quantization_config) partition.append(bias) - if _is_annotated(partition): + if _is_annotated(partition) or skip: continue if filter_fn and any(not filter_fn(n) for n in partition): @@ -311,7 +349,12 @@ def _do_annotate_conv_relu( weight = conv_node.args[1] assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + weight_qspec = get_weight_qspec(quantization_config) + groups = get_groups_from_conv(conv_node) + if is_conv_transpose: + # transposed convs per output channel quantization + weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) + input_qspec_map[weight] = weight_qspec # adding weight node to the partition as well partition = [relu_node, conv_node, conv_node.args[1]] @@ -323,6 +366,9 @@ def _do_annotate_conv_relu( if _is_annotated(partition): continue + if is_conv_transpose and groups != 1: + continue + if filter_fn and any(not filter_fn(n) for n in partition): continue diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 312cbc17b95..b724ab1a9d9 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -601,128 +602,6 @@ Error defineTensor( #define MAYBE_UNUSED(x) (void)(x) -/* -Define serialized add node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineAddNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - std::pair min_max = getOutputMinMax(node); - auto graph_node = node->xnode_union_as_XNNAdd(); - xnn_status status = xnn_define_add2( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create add node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define Minimum operator Node into the subgraph -*/ -Error defineMinimumNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMinimum(); - xnn_status status = xnn_define_minimum2( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create minumum node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define subtract operator Node into the subgraph -*/ -Error defineSubtractNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSubtract(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_subtract( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create subtract node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - -/* -Define Multiply operator Node into the subgraph -*/ -Error defineMultiplyNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMultiply(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_multiply2( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create multiply node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -}; - #ifdef ENABLE_XNNPACK_KLEIDI bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert); @@ -843,38 +722,6 @@ Error defineFullyConnectedNode( return Error::Ok; }; -/* -Define serialized clamp node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineClampNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - std::pair min_max = getOutputMinMax(node); - auto graph_node = node->xnode_union_as_XNNClamp(); - xnn_status status = xnn_define_clamp( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create hardtanh node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - /* Define serialized softmax node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining @@ -903,62 +750,6 @@ Error defineSoftmaxNode( return Error::Ok; } -/* -Define serialized sigmoid node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineSigmoidNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSigmoid(); - xnn_status status = xnn_define_sigmoid( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create sigmoid node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized floor node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining -the tensor value -*/ -Error defineFloorNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNFloor(); - xnn_status status = xnn_define_floor( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create floor node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - Error defineGlobalAvgPooling2dNode( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, @@ -1155,36 +946,6 @@ Error defineMaxPooling2dNode( return Error::Ok; } -/* -Define serialized div node into the subgraph -*/ -Error defineDivNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNDiv(); - std::pair min_max = getOutputMinMax(node); - xnn_status status = xnn_define_divide( - subgraph_ptr, - min_max.first, - min_max.second, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create div node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - /* Define serialized static transpose node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the @@ -1402,20 +1163,20 @@ Error defineArgMaxPooling2dNode( } /* -Define serialized square root node into the subgraph, using the remapped ids +Define serialized tanh node into the subgraph, using the remapped ids to map the serialized ids, to the new ids generated when defining the tensor value */ -Error defineSquareRootNode( +Error defineTanhNode( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNSquareRoot(); + auto graph_node = node->xnode_union_as_XNNTanh(); - xnn_status status = xnn_define_square_root( + xnn_status status = xnn_define_tanh( subgraph_ptr, remapped_ids.at(graph_node->input_id()), remapped_ids.at(graph_node->output_id()), @@ -1424,7 +1185,7 @@ Error defineSquareRootNode( ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create square root node %i with code: %s", + "Failed to create tanh node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1432,29 +1193,30 @@ Error defineSquareRootNode( } /* -Define serialized square root node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized prelu node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineReciprocalSquareRootNode( +Error definePReLUNode( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNReciprocalSquareRoot(); + auto graph_node = node->xnode_union_as_XNNPReLU(); - xnn_status status = xnn_define_reciprocal_square_root( + xnn_status status = xnn_define_prelu( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create reciprocal square root node %i with code: %s", + "Failed to create prelu node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1462,29 +1224,31 @@ Error defineReciprocalSquareRootNode( } /* -Define serialized log node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate2 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineLogNode( +Error defineConcatenate2Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNLog(); + auto graph_node = node->xnode_union_as_XNNConcatenate2(); - xnn_status status = xnn_define_log( + xnn_status status = xnn_define_concatenate2( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create log node %i with code: %s", + "Failed to create cat2 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1492,29 +1256,32 @@ Error defineLogNode( } /* -Define serialized gelu node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate3 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineGeluNode( +Error defineConcatenate3Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNGelu(); + auto graph_node = node->xnode_union_as_XNNConcatenate3(); - xnn_status status = xnn_define_gelu( + xnn_status status = xnn_define_concatenate3( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create gelu node %i with code: %s", + "Failed to create cat3 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -1522,368 +1289,33 @@ Error defineGeluNode( } /* -Define serialized ceiling node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value +Defines serialized concatenate4 node into the subgraph, +using the remapped ids to map the serialized ids, +to the new ids generated when defining the tensor value */ -Error defineCeilingNode( +Error defineConcatenate4Node( xnn_subgraph_t subgraph_ptr, const std::unordered_map& remapped_ids, const NodePtr node, const fb_xnnpack::XNNGraph* graph) noexcept { MAYBE_UNUSED(graph); - auto graph_node = node->xnode_union_as_XNNCeiling(); + auto graph_node = node->xnode_union_as_XNNConcatenate4(); - xnn_status status = xnn_define_ceiling( + xnn_status status = xnn_define_concatenate4( subgraph_ptr, - remapped_ids.at(graph_node->input_id()), + graph_node->axis(), + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->input3_id()), + remapped_ids.at(graph_node->input4_id()), remapped_ids.at(graph_node->output_id()), graph_node->flags()); ET_CHECK_OR_RETURN_ERROR( status == xnn_status_success, Internal, - "Failed to create ceiling node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized hardswish node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineHardswishNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNHardswish(); - - xnn_status status = xnn_define_hardswish( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create hardswish node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized leaky relu node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineLeakyReLUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNLeakyReLU(); - - xnn_status status = xnn_define_leaky_relu( - subgraph_ptr, - graph_node->negative_slope(), - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create leaky relu node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define serialized maximum node into the subgraph, using the remapped ids -to map the serialized ids, to the new ids generated when defining the -tensor value -*/ -Error defineMaximumNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNMaximum(); - - xnn_status status = xnn_define_maximum2( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create maximum node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Define Negate node into subgraph, using the remapped ids to map the -serialized ids, to the new ids generated when defining the tensor value -*/ -Error defineNegateNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNNegate(); - - xnn_status status = xnn_define_negate( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create negate node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines square node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineSquareNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNSquare(); - - xnn_status status = xnn_define_square( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create square node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines square node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineELUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNELU(); - - xnn_status status = xnn_define_elu( - subgraph_ptr, - graph_node->alpha(), - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create ELU node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines absolute value node into subgraph using the remapped ids to map the -serialized ids to the new ids generated when defining the tensor value -*/ -Error defineAbsNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNAbs(); - - xnn_status status = xnn_define_abs( - subgraph_ptr, - remapped_ids.at(graph_node->input_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create abs node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized prelu node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error definePReLUNode( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNPReLU(); - - xnn_status status = xnn_define_prelu( - subgraph_ptr, - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create prelu node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate2 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate2Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate2(); - - xnn_status status = xnn_define_concatenate2( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat2 node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate3 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate3Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate3(); - - xnn_status status = xnn_define_concatenate3( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->input3_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat3 node %i with code: %s", - node->debug_handle(), - xnn_status_to_string(status)); - - return Error::Ok; -} - -/* -Defines serialized concatenate4 node into the subgraph, -using the remapped ids to map the serialized ids, -to the new ids generated when defining the tensor value -*/ -Error defineConcatenate4Node( - xnn_subgraph_t subgraph_ptr, - const std::unordered_map& remapped_ids, - const NodePtr node, - const fb_xnnpack::XNNGraph* graph) noexcept { - MAYBE_UNUSED(graph); - - auto graph_node = node->xnode_union_as_XNNConcatenate4(); - - xnn_status status = xnn_define_concatenate4( - subgraph_ptr, - graph_node->axis(), - remapped_ids.at(graph_node->input1_id()), - remapped_ids.at(graph_node->input2_id()), - remapped_ids.at(graph_node->input3_id()), - remapped_ids.at(graph_node->input4_id()), - remapped_ids.at(graph_node->output_id()), - graph_node->flags()); - - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Failed to create cat4 node %i with code: %s", + "Failed to create cat4 node %i with code: %s", node->debug_handle(), xnn_status_to_string(status)); @@ -2047,6 +1479,196 @@ Error defineNotImplementedNode( fb_xnnpack::EnumNameXNodeUnion(node->xnode_union_type())); } +// Generic helper function for unary operations +Error defineGenericUnaryNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + uint32_t input_id, + uint32_t output_id, + uint32_t flags, + xnn_unary_operator op_type, + const union xnn_unary_params* params, + fb_xnnpack::XNodeUnion node_type, + uint32_t debug_handle) noexcept { + xnn_status status = xnn_define_unary( + subgraph_ptr, + op_type, + params, + remapped_ids.at(input_id), + remapped_ids.at(output_id), + flags); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create %s node %i with code: %s", + fb_xnnpack::EnumNameXNodeUnion(node_type), + debug_handle, + xnn_status_to_string(status)); + + return Error::Ok; +} + +// Macro for unary operations with no parameters +#define _DEFINE_UNARY_NODE_NO_PARAMS(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + op_type, \ + nullptr, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with min/max parameters +#define _DEFINE_UNARY_NODE_WITH_MINMAX(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + std::pair min_max = getOutputMinMax(node); \ + union xnn_unary_params params = { \ + .clamp = {.min = min_max.first, .max = min_max.second}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + op_type, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with leaky_relu parameters +#define _DEFINE_UNARY_NODE_WITH_LEAKY_RELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNLeakyReLU(); \ + union xnn_unary_params params = { \ + .leaky_relu = {.negative_slope = graph_node->negative_slope()}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_leaky_relu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for unary operations with elu parameters +#define _DEFINE_UNARY_NODE_WITH_ELU(name) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNNELU(); \ + union xnn_unary_params params = {.elu = {.alpha = graph_node->alpha()}}; \ + return defineGenericUnaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node->input_id(), \ + graph_node->output_id(), \ + graph_node->flags(), \ + xnn_unary_elu, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Generic helper function for binary operations +Error defineGenericBinaryNode( + xnn_subgraph_t subgraph_ptr, + const std::unordered_map& remapped_ids, + const fb_xnnpack::_XNNNode2x1* graph_node, + xnn_binary_operator op_type, + const struct xnn_binary_params* params, + fb_xnnpack::XNodeUnion node_type, + uint32_t debug_handle) noexcept { + xnn_status status = xnn_define_binary( + subgraph_ptr, + op_type, + params, + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + + ET_CHECK_OR_RETURN_ERROR( + status == xnn_status_success, + Internal, + "Failed to create %s node %i with code: %s", + fb_xnnpack::EnumNameXNodeUnion(node_type), + debug_handle, + xnn_status_to_string(status)); + + return Error::Ok; +} + +// Macro for binary operations with min/max parameters +#define _DEFINE_BINARY_NODE_WITH_MINMAX(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + std::pair min_max = getOutputMinMax(node); \ + struct xnn_binary_params params = { \ + .output_min = min_max.first, .output_max = min_max.second}; \ + return defineGenericBinaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node, \ + op_type, \ + ¶ms, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + +// Macro for binary operations without parameters +#define _DEFINE_BINARY_NODE_NO_PARAMS(name, op_type) \ + Error define##name##Node( \ + xnn_subgraph_t subgraph_ptr, \ + const std::unordered_map& remapped_ids, \ + const NodePtr node, \ + const fb_xnnpack::XNNGraph* graph) noexcept { \ + MAYBE_UNUSED(graph); \ + auto graph_node = node->xnode_union_as_XNN##name(); \ + return defineGenericBinaryNode( \ + subgraph_ptr, \ + remapped_ids, \ + graph_node, \ + op_type, \ + nullptr, \ + node->xnode_union_type(), \ + node->debug_handle()); \ + } + /* Returns the pointer to the defineNode function that handles the given XNode type @@ -2055,43 +1677,81 @@ XNode type case fb_xnnpack::XNodeUnion::XNN##name: \ return &define##name##Node; +// Unary Ops with no params +_DEFINE_UNARY_NODE_NO_PARAMS(Sigmoid, xnn_unary_sigmoid) +_DEFINE_UNARY_NODE_NO_PARAMS(Floor, xnn_unary_floor) +_DEFINE_UNARY_NODE_NO_PARAMS(SquareRoot, xnn_unary_square_root) +_DEFINE_UNARY_NODE_NO_PARAMS( + ReciprocalSquareRoot, + xnn_unary_reciprocal_square_root) +_DEFINE_UNARY_NODE_NO_PARAMS(Ceiling, xnn_unary_ceiling) +_DEFINE_UNARY_NODE_NO_PARAMS(Gelu, xnn_unary_gelu) +_DEFINE_UNARY_NODE_NO_PARAMS(Hardswish, xnn_unary_hardswish) +_DEFINE_UNARY_NODE_NO_PARAMS(Log, xnn_unary_log) +_DEFINE_UNARY_NODE_NO_PARAMS(Negate, xnn_unary_negate) +_DEFINE_UNARY_NODE_NO_PARAMS(Square, xnn_unary_square) +_DEFINE_UNARY_NODE_NO_PARAMS(Abs, xnn_unary_abs) + +// Unary Ops with min/max params +_DEFINE_UNARY_NODE_WITH_MINMAX(Clamp, xnn_unary_clamp) + +// Unary Ops with specific params +_DEFINE_UNARY_NODE_WITH_LEAKY_RELU(LeakyReLU) +_DEFINE_UNARY_NODE_WITH_ELU(ELU) + +// Binary Ops with params +_DEFINE_BINARY_NODE_WITH_MINMAX(Add, xnn_binary_add) +_DEFINE_BINARY_NODE_WITH_MINMAX(Subtract, xnn_binary_subtract) +_DEFINE_BINARY_NODE_WITH_MINMAX(Multiply, xnn_binary_multiply) +_DEFINE_BINARY_NODE_WITH_MINMAX(Div, xnn_binary_divide) + +// Binary Ops without params +_DEFINE_BINARY_NODE_NO_PARAMS(Minimum, xnn_binary_minimum) +_DEFINE_BINARY_NODE_NO_PARAMS(Maximum, xnn_binary_maximum) + DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) { switch (nodeType) { + // Binary ops _DEFINE(Add) - _DEFINE(FullyConnected) + _DEFINE(Subtract) + _DEFINE(Multiply) + _DEFINE(Div) + _DEFINE(Minimum) + _DEFINE(Maximum) + + // Unary ops _DEFINE(Softmax) + _DEFINE(SquareRoot) + _DEFINE(ReciprocalSquareRoot) + _DEFINE(Ceiling) + _DEFINE(Gelu) + _DEFINE(Hardswish) + _DEFINE(Log) + _DEFINE(Tanh) + _DEFINE(Negate) + _DEFINE(Square) + _DEFINE(Clamp) + _DEFINE(LeakyReLU) + _DEFINE(ELU) + _DEFINE(Abs) + _DEFINE(Floor) + _DEFINE(PReLU) _DEFINE(Sigmoid) + + // Others + _DEFINE(FullyConnected) _DEFINE(StaticTranspose) - _DEFINE(Clamp) _DEFINE(Conv2d) _DEFINE(ConvTranspose2d) - _DEFINE(Div) _DEFINE(StaticResizeBilinear2D) _DEFINE(StaticConstantPad) _DEFINE(AvgPooling2d) - _DEFINE(Minimum) _DEFINE(DepthwiseConv2d) _DEFINE(MaxPooling2d) - _DEFINE(Multiply) - _DEFINE(Subtract) - _DEFINE(Floor) _DEFINE(Convert) _DEFINE(GlobalAvgPooling2d) _DEFINE(StaticReshape) _DEFINE(ArgMaxPooling2d) - _DEFINE(SquareRoot) - _DEFINE(ReciprocalSquareRoot) - _DEFINE(Ceiling) - _DEFINE(Gelu) - _DEFINE(Hardswish) - _DEFINE(LeakyReLU) - _DEFINE(Log) - _DEFINE(Maximum) - _DEFINE(Negate) - _DEFINE(Square) - _DEFINE(ELU) - _DEFINE(Abs) - _DEFINE(PReLU) _DEFINE(Concatenate2) _DEFINE(Concatenate3) _DEFINE(Concatenate4) diff --git a/backends/xnnpack/serialization/runtime_schema.fbs b/backends/xnnpack/serialization/runtime_schema.fbs index a0d44327912..eea4cdb8b86 100644 --- a/backends/xnnpack/serialization/runtime_schema.fbs +++ b/backends/xnnpack/serialization/runtime_schema.fbs @@ -154,6 +154,7 @@ union XNodeUnion { XNNReciprocalSquareRoot: _XNNNode1x1, XNNLog: _XNNNode1x1, XNNGelu: _XNNNode1x1, + XNNTanh: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/schema.fbs b/backends/xnnpack/serialization/schema.fbs index eeab28154cc..ed444005c64 100644 --- a/backends/xnnpack/serialization/schema.fbs +++ b/backends/xnnpack/serialization/schema.fbs @@ -150,6 +150,7 @@ union XNodeUnion { XNNReciprocalSquareRoot: _XNNNode1x1, XNNLog: _XNNNode1x1, XNNGelu: _XNNNode1x1, + XNNTanh: _XNNNode1x1, } union XValueUnion { diff --git a/backends/xnnpack/serialization/xnnpack_graph_schema.py b/backends/xnnpack/serialization/xnnpack_graph_schema.py index dc50fb47da4..106eb6b81d9 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_schema.py +++ b/backends/xnnpack/serialization/xnnpack_graph_schema.py @@ -319,6 +319,11 @@ class XNNLog(XNNNode1x1): pass +@dataclass +class XNNTanh(XNNNode1x1): + pass + + @dataclass class XNNMaximum(XNNNode2x1): pass @@ -391,6 +396,7 @@ class XNNScaledDotProductAttention: XNNReciprocalSquareRoot, XNNLog, XNNGelu, + XNNTanh, ] diff --git a/backends/xnnpack/test/__init__.py b/backends/xnnpack/test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/xnnpack/test/ops/test_conv2d.py b/backends/xnnpack/test/ops/test_conv2d.py index 92bb03c907a..2a0a82d99b6 100644 --- a/backends/xnnpack/test/ops/test_conv2d.py +++ b/backends/xnnpack/test/ops/test_conv2d.py @@ -174,14 +174,11 @@ def get_inputs(self): class Conv2dDQSeq(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=8, out_channels=10, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=8, out_channels=10, kernel_size=3, padding=1) def forward(self, x): y = self.first(x) @@ -192,14 +189,11 @@ def get_inputs(self): class Conv2dDQParallel(torch.nn.Module): - def __init__(self): + def __init__(self, transpose=False): super().__init__() - self.first = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) - self.second = torch.nn.Conv2d( - in_channels=3, out_channels=8, kernel_size=3, padding=1 - ) + op = torch.nn.ConvTranspose2d if transpose else torch.nn.Conv2d + self.first = op(in_channels=3, out_channels=8, kernel_size=3, padding=1) + self.second = op(in_channels=3, out_channels=10, kernel_size=3, padding=1) def forward(self, x): first = self.first(x) @@ -221,7 +215,6 @@ def _test( conv_count=1, dtype: torch.dtype = torch.float, check_quantized=True, - delegated=True, ): # pyre-fixme[29]: `Union[torch._tensor.Tensor, # torch.nn.modules.module.Module]` is not a function. @@ -240,29 +233,20 @@ def _test( (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower()) - if delegated: - ( - tester.check_not( - ["executorch_exir_dialects_edge__ops_aten_convolution_default"] - ) - .check_not( - [ - "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" - ] - ) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - .run_method_and_compare_outputs(qtol=1) + ( + tester.check_not( + ["executorch_exir_dialects_edge__ops_aten_convolution_default"] ) - else: - # need quantize ops when ops are not delegated to xnnpack - if has_quantized_ops: - ( - tester.to_executorch() - .serialize() - .run_method_and_compare_outputs(qtol=1) - ) + .check_not( + [ + "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(qtol=1) + ) def _test_dq( self, @@ -276,8 +260,7 @@ def _test_dq( ) DynamicallyQuantizedPartitioner = XnnpackPartitioner( - config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, - per_op_mode=True, + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, per_op_mode=True ) tester = Tester(m, m.get_inputs(), dynamic_shapes=dynamic_shapes) @@ -325,7 +308,6 @@ def test_qs8_conv2d_per_channel(self) -> None: self._test( Conv2d(transpose=transpose), quant_config=get_symmetric_quantization_config(is_per_channel=True), - delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 ) def test_fp32_conv2d_seq(self) -> None: @@ -360,11 +342,10 @@ def test_fp32_conv2d_depthwise(self): ) def test_qs8_conv2d_depthwise(self): - for transpose in (True, False): - self._test( - Conv2d(groups=2, in_channels=2, out_channels=6, transpose=transpose), - quant_config=get_symmetric_quantization_config(), - ) + self._test( + Conv2d(groups=2, in_channels=2, out_channels=6), + quant_config=get_symmetric_quantization_config(), + ) def test_fp32_conv2d_bn(self): class Conv2dBatchNorm(torch.nn.Module): @@ -485,7 +466,6 @@ def get_inputs(self): self._test( ConvReLU(transpose=transpose), quant_config=get_symmetric_quantization_config(is_per_channel=True), - delegated=not transpose, # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 ) def test_qs8_conv2d_dw_relu(self): @@ -527,19 +507,14 @@ def forward(self, x): def get_inputs(self): return (torch.randn(batches, in_channels, height, width) * 11,) - for transpose in (True, False): - for per_channel_quant in (False, True): - if transpose and per_channel_quant: - continue - model = ModelConvReLU(transpose=transpose) - self._test( - model, - quant_config=get_symmetric_quantization_config( - is_per_channel=per_channel_quant - ), - # XNNPACK does not support per input channel quantization for transpose convolutions with groups > 1 - delegated=not (transpose and per_channel_quant), - ) + for per_channel_quant in (False, True): + model = ModelConvReLU() + self._test( + model, + quant_config=get_symmetric_quantization_config( + is_per_channel=per_channel_quant + ), + ) def test_qs8_conv2d_relu_seq(self): class ConvReLUSeq(torch.nn.Module): @@ -593,7 +568,7 @@ def get_inputs(self): conv_count=2, ) - def test_qs8_conv_transpose_2d_quantize_per_channel(self): + def test_qs8_conv_transpose_2d_quantize_per_channel_multi_axis(self): class PerChannelConvTranspose2d(torch.nn.Module): def __init__(self, input_channels, output_channels, groups, axis): super().__init__() @@ -662,76 +637,24 @@ def get_inputs(self): ) for groups in (1, 2): - for axis in (0, 1): - self._test( - PerChannelConvTranspose2d(3 * groups, 5 * groups, groups, axis), - quant_config=None, - conv_count=1, - delegated=axis == 1 - and groups - == 1, # xnnpack only support output channel axis quantization with groups == 1 - ) - - def test_qs8_conv_transpose_2d_dqd_f32_weights(self): - class TransposeConv2dDQDf32weights(torch.nn.Module): - def __init__(self, input_channels, output_channels, groups, axis): - super().__init__() - self.input_channels = input_channels - self.output_channels = output_channels - self.axis = axis - self.groups = groups - self.transpose = True - self.weights = torch.nn.Parameter( - torch.randn((input_channels, output_channels // groups, 4, 4)), - requires_grad=False, - ) - - axis_size = self.weights.shape[axis] - self.scale = torch.nn.Parameter(torch.ones(axis_size) * 0.12345) - self.zero_point = torch.nn.Parameter( - torch.zeros((axis_size,), dtype=torch.int64), requires_grad=False - ) - - def forward(self, x): - dequantize_input = ( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - x, 0.12345, 0, -127, 127, torch.int8 + for ch_axis in (1, 2): + if ch_axis == 1 and groups == 1: + self._test( + PerChannelConvTranspose2d( + 3 * groups, 5 * groups, groups, ch_axis + ), # ch_axis=0 + quant_config=None, + conv_count=1, ) - ) - x = torch.nn.functional.conv_transpose2d( - dequantize_input, self.weights, groups=self.groups - ) - - return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default( - x, - 0.12345, - 0, - -127, - 127, - torch.int8, - ), - 0.12345, - 0, - -127, - 127, - torch.int8, - ) - - def get_inputs(self): - return ( - torch.randint( - low=-127, high=127, size=(3, self.input_channels, 4, 4) - ).type(dtype=torch.int8), - ) - - for groups in (1, 2): - for axis in (0, 1): - self._test( - TransposeConv2dDQDf32weights(3 * groups, 5 * groups, groups, axis), - quant_config=None, - conv_count=1, - ) + else: + with self.assertRaises(RuntimeError): + self._test( + PerChannelConvTranspose2d( + 3 * groups, 5 * groups, groups, ch_axis + ), # ch_axis=0 + quant_config=None, + conv_count=1, + ) def test_padded_output_tconv(self): class TConv2d(torch.nn.Module): @@ -761,7 +684,7 @@ def forward(self, x): (tester.export().check_count({op: conv_count}).to_edge_transform_and_lower()) - # tconv should not be offloaded to XNNPack, since output padding is not + # tconv should not be offloaded to XNNPack, since output padding is not supported ( tester.check( ["executorch_exir_dialects_edge__ops_aten_convolution_default"] @@ -794,3 +717,31 @@ def test_dq_conv2d_parallel(self) -> None: model = Conv2dDQParallel() conv_count = sum(1 for m in model.modules() if type(m) is torch.nn.Conv2d) self._test_dq(model, conv_count) + + def test_dq_conv2d_transpose(self) -> None: + model = Conv2d( + in_channels=3, + out_channels=10, + kernel_size=(3, 3), + stride=(1, 1), + padding=(0, 0), + batches=1, + width=8, + height=8, + transpose=True, + ) + self._test_dq(model) + + def test_dq_conv2d_transpose_seq(self) -> None: + model = Conv2dDQSeq(transpose=True) + conv_count = sum( + 1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d + ) + self._test_dq(model, conv_count) + + def test_dq_conv2d_transpose_parallel(self) -> None: + model = Conv2dDQParallel(transpose=True) + conv_count = sum( + 1 for m in model.modules() if type(m) is torch.nn.ConvTranspose2d + ) + self._test_dq(model, conv_count) diff --git a/backends/xnnpack/test/ops/test_div.py b/backends/xnnpack/test/ops/test_div.py index b53c59df8e1..3d9835ec56e 100644 --- a/backends/xnnpack/test/ops/test_div.py +++ b/backends/xnnpack/test/ops/test_div.py @@ -31,17 +31,20 @@ def forward(self, x): return z def _test_div(self, inputs): - ( - Tester(self.Div(), inputs) - .export() - .check_count({"torch.ops.aten.div.Tensor": 1}) - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) - .to_executorch() - .serialize() - .run_method_and_compare_outputs() - ) + for legacy_mode in (True, False): + tester = Tester(self.Div(), inputs) + tester.export() + tester.check_count({"torch.ops.aten.div.Tensor": 1}) + if legacy_mode: + tester.to_edge() + tester.partition() + else: + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) + tester.to_executorch() + tester.serialize() + tester.run_method_and_compare_outputs() def test_fp16_div(self): # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough @@ -59,14 +62,17 @@ def test_fp32_div(self): def test_fp32_div_single_input(self): # Adding 4 to move distribution away from 0, 4 Std Dev should be far enough inputs = (torch.randn(1) + 4,) - ( - Tester(self.DivSingleInput(), inputs) - .export() - .check_count({"torch.ops.aten.div.Tensor": 1}) - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) - .to_executorch() - .serialize() - .run_method_and_compare_outputs() - ) + for legacy_mode in (True, False): + tester = Tester(self.DivSingleInput(), inputs) + tester.export() + tester.check_count({"torch.ops.aten.div.Tensor": 1}) + if legacy_mode: + tester.to_edge() + tester.partition() + else: + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"]) + tester.to_executorch() + tester.serialize() + tester.run_method_and_compare_outputs() diff --git a/backends/xnnpack/test/ops/test_tanh.py b/backends/xnnpack/test/ops/test_tanh.py new file mode 100644 index 00000000000..e7bac4541c9 --- /dev/null +++ b/backends/xnnpack/test/ops/test_tanh.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack.test.tester import Tester + + +class TestTanh(unittest.TestCase): + def setUp(self): + torch._dynamo.reset() + + class Tanh(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.tanh(x) + + def run_tanh_test(self, inputs): + ( + Tester(self.Tanh(), inputs) + .export() + .check_count({"torch.ops.aten.tanh.default": 1}) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp16_tanh(self): + inputs = (torch.randn(20).to(torch.float16),) + self.run_tanh_test(inputs) + + def test_fp32_tanh(self): + inputs = (torch.randn(20),) + self.run_tanh_test(inputs) diff --git a/backends/xnnpack/test/passes/__init__.py b/backends/xnnpack/test/passes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/xnnpack/test/passes/test_activation_fusion.py b/backends/xnnpack/test/passes/test_activation_fusion.py index 6a1182dc7fb..5d65679948b 100644 --- a/backends/xnnpack/test/passes/test_activation_fusion.py +++ b/backends/xnnpack/test/passes/test_activation_fusion.py @@ -7,6 +7,7 @@ import unittest import torch +from executorch.backends.test.harness.stages import StageType from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass from executorch.backends.xnnpack.test.tester import RunPasses, Tester @@ -59,7 +60,7 @@ def _test_op_activation_case( .to_edge() .run_passes(self.PassStage) .check_not([activation_name]) - .get_artifact(Tester.stage_name(self.PassStage)) + .get_artifact(StageType.RUN_PASSES) ) for node in artifact.exported_program().module().graph.nodes: diff --git a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py index 2cede185773..cfc409b4596 100644 --- a/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py +++ b/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py @@ -335,3 +335,50 @@ def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None: ) .run_method_and_compare_outputs() ) + + class ConvAddConvOutput(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 16, 3) + self.conv2 = torch.nn.Conv2d(16, 16, 3) + + def forward(self, x): + y = self.conv1(x) + z = torch.add(y, 1.0) + out1 = self.conv2(z) + out2 = z + return out1, out2 + + ConvAddConvOutputModule = ConvAddConvOutput() + + def test_conv_add_conv_output(self): + x = torch.randn(1, 3, 8, 8) + + self.run_tester(self.ConvAddConvOutput().eval(), (x,)) + + x_cl = x.to(memory_format=torch.channels_last) + self.run_tester(self.ConvAddConvOutput().eval(), (x_cl,)) + + class ThreeOutputsModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + self.conv2 = torch.nn.Conv2d(3, 3, 3) + self.linear = torch.nn.Linear(6, 6) + + def forward(self, x): + conv1_out = self.conv1(x) + conv2_out = self.conv2(x) + linear_out = self.linear(x) + + return linear_out, conv1_out, conv2_out + + ThreeOutputsModelModule = ThreeOutputsModel() + + def test_three_outputs_model(self): + x = torch.randn(1, 3, 6, 6) + + self.run_tester(self.ThreeOutputsModelModule.eval(), (x,)) + + x_cl = x.to(memory_format=torch.channels_last) + self.run_tester(self.ThreeOutputsModelModule.eval(), (x_cl,)) diff --git a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py index 6fec7726835..2347122a180 100644 --- a/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py +++ b/backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py @@ -7,6 +7,7 @@ import unittest import torch +from executorch.backends.test.harness.stages import StageType from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import ( TagImplicitQDqPass, ) @@ -61,7 +62,7 @@ def test_tag_implicit_q_dq_test(self): .to_edge() .run_passes(self.PassStage) .run_method_and_compare_outputs() - .get_artifact(Tester.stage_name(self.PassStage)) + .get_artifact(StageType.RUN_PASSES) ) for node in artifact.exported_program().module().graph.nodes: diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 5053458613e..f97236bed7b 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -6,6 +6,8 @@ # pyre-unsafe +import unittest + from collections import Counter from typing import Dict, Tuple @@ -723,6 +725,7 @@ def test_save_load(self) -> None: instantiate_parametrized_tests(TestQuantizePT2E) +@unittest.skip("TODO: Reenable it after debug infrature finish update") class TestNumericDebugger(TestCase): def _extract_debug_handles(self, model) -> Dict[str, int]: debug_handle_map: Dict[str, int] = {} diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 0a317ad8822..84b1a932a5b 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -120,6 +120,116 @@ def test_conv1d_with_conv2d(self): node_list, ) + def test_q_tconv_and_conv2d(self): + class TConv2dConv2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + + def forward(self, x): + y = self.first(x) + return self.second(y) + + def example_inputs(self): + return (torch.randn(1, 1, 3, 3),) + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_operator_type( + torch.ops.aten.conv_transpose2d.input, quantization_config + ) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + ] + m = TConv2dConv2d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + is_debug_mode=True, + ) + + def test_q_conv2_and_tconv2d(self): + class TConv2dConv2d(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=False, + ) + + def forward(self, x): + y = self.first(x) + return self.second(y) + + def example_inputs(self): + return (torch.randn(1, 1, 3, 3),) + + quantizer = XNNPACKQuantizer() + quantization_config = get_symmetric_quantization_config(is_per_channel=True) + quantizer.set_operator_type(torch.ops.aten.conv2d.default, quantization_config) + node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.aten.conv_transpose2d.input, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + ] + m = TConv2dConv2d() + self._test_quantizer( + m, + m.example_inputs(), + quantizer, + node_occurrence, + node_list, + is_debug_mode=True, + ) + def test_linear(self): quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/backends/xnnpack/test/tester/TARGETS b/backends/xnnpack/test/tester/TARGETS index 231de970d7b..44925141cca 100644 --- a/backends/xnnpack/test/tester/TARGETS +++ b/backends/xnnpack/test/tester/TARGETS @@ -15,6 +15,7 @@ runtime.python_library( ], deps = [ "//caffe2:torch", + "//executorch/backends/test/harness:tester", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils", diff --git a/backends/xnnpack/test/tester/__init__.py b/backends/xnnpack/test/tester/__init__.py index f92088c72e8..44933c43309 100644 --- a/backends/xnnpack/test/tester/__init__.py +++ b/backends/xnnpack/test/tester/__init__.py @@ -13,16 +13,18 @@ Serialize, Tester, ToEdge, + ToEdgeTransformAndLower, ToExecutorch, ) __all__ = [ - Tester, - Partition, - Quantize, Export, ToEdge, + Partition, + Quantize, RunPasses, - ToExecutorch, + ToEdgeTransformAndLower, + Tester, Serialize, + ToExecutorch, ] diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index dad0d5ad0e0..62eb504faa7 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -5,47 +5,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy +from typing import Any, Callable, List, Optional, Sequence, Tuple, Type -import logging -import random -import sys -from abc import ABC, abstractmethod -from collections import Counter, OrderedDict -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +import executorch +import executorch.backends.test.harness.stages as BaseStages import torch -from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( - DuplicateDynamicQuantChainPass, -) +from executorch.backends.test.harness import Tester as TesterBase +from executorch.backends.test.harness.stages import StageType from executorch.backends.xnnpack._passes import XNNPACKPassManager from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner -from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config -from executorch.exir import ( - EdgeCompileConfig, - EdgeProgramManager, - ExecutorchBackendConfig, - ExecutorchProgramManager, - to_edge, - to_edge_transform_and_lower, -) -from executorch.exir.backend.backend_api import validation_disabled -from executorch.exir.backend.partitioner import Partitioner -from executorch.exir.dim_order_utils import get_memory_format -from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass - -from executorch.exir.print_program import pretty_print, print_program -from torch.export import export_for_training - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -try: - from executorch.extension.pybindings.portable_lib import ( # @manual - _load_for_executorch_from_buffer, - ) -except ImportError as e: - logger.warning(f"{e=}") - pass from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, @@ -54,101 +23,18 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( QuantizationConfig, ) -from executorch.exir.program._program import _transform +from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.partitioner import Partitioner from torch._export.pass_base import PassType -from torch.export import export, ExportedProgram -from torch.testing import FileCheck -from torch.utils._pytree import tree_flatten -from torchao.quantization.pt2e.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) from torchao.quantization.pt2e.quantizer import Quantizer -class Stage(ABC): - """ - Interface for a Stage in the PT2.0 lowering pipeline - """ - - @abstractmethod - def run(self, artifact, inputs): - """ - Executes this stage, generates the 'artifact', for later stages. - """ - pass - - @property - @abstractmethod - def artifact(self): - """ - Returns the artifact generated by this stage. To be used by the next stage in the pipeline. - """ - pass - - @property - @abstractmethod - def graph_module(self): - """ - Return the artifact's graph module for this stage - """ - pass - - def run_artifact(self, inputs): - """ - Returns the output of calling the artifact generated by this stage with inputs - """ - if isinstance(self.artifact, ExportedProgram): - return self.artifact(*inputs) - else: - return self.artifact.exported_program().module()(*inputs) - - # Debug Tools for stages - def artifact_str(self): - """ - Return string printable artifact for this stage - """ - if isinstance(self.artifact, EdgeProgramManager): - return self.artifact.exported_program() - return self.artifact - - def stage_banner(self): - """ - Returns banner string for this stage - """ - return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n" - - def dump_artifact(self, path_to_dump: Optional[str]): - """ - Dumps string printable artifact to path. If path_to_dump, then it is printed to terminal - """ - if path_to_dump: - with open(path_to_dump, "a") as fp: - fp.write(str(self.stage_banner() + "\n")) - fp.write(str(self.artifact_str())) - else: - print(self.stage_banner() + "\n") - print(self.artifact_str()) - - -_stages_: Dict[str, Stage] = {} - - -def register_stage(stage: Stage): - """ - Register a Stage to be used in the Tester. - """ - assert isinstance(stage, type) - name = stage.__qualname__ - if name in _stages_: - raise RuntimeError(f"Duplicate stage in Tester, {name}") - _stages_[name] = stage - return stage +class Export(BaseStages.Export): + pass -@register_stage -class Quantize(Stage): +class Quantize(BaseStages.Quantize): def __init__( self, quantizer: Optional[Quantizer] = None, @@ -157,666 +43,88 @@ def __init__( calibration_samples: Optional[Sequence[Any]] = None, is_qat: Optional[bool] = False, ): - self.quantizer = quantizer or XNNPACKQuantizer() - self.quantization_config = ( - quantization_config or get_symmetric_quantization_config(is_qat=is_qat) - ) - self.calibrate = calibrate - self.calibration_samples = calibration_samples - - self.quantizer.set_global(self.quantization_config) - - self.converted_graph = None - self.is_qat = is_qat - - def run( - self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]] - ) -> None: - assert inputs is not None - if self.is_qat: - artifact.train() - captured_graph = export_for_training(artifact, inputs, strict=True).module() - - assert isinstance(captured_graph, torch.fx.GraphModule) - - if self.is_qat: - prepared = prepare_qat_pt2e(captured_graph, self.quantizer) - else: - prepared = prepare_pt2e(captured_graph, self.quantizer) - - if self.calibrate: - # Calibrate prepared model to provide data to quantization observers. - if self.calibration_samples is not None: - for inp in self.calibration_samples: - prepared(*inp) - else: - prepared(*inputs) - - converted = convert_pt2e(prepared) - DuplicateDynamicQuantChainPass()(converted) - - self.converted_graph = converted - - @property - def artifact(self) -> torch.fx.GraphModule: - return self.converted_graph - - @property - def graph_module(self) -> str: - return self.converted_graph - - def run_artifact(self, inputs): - return self.converted_graph.forward(*inputs) - - -@register_stage -class Export(Stage): - def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None): - self.exported_program = None - self.dynamic_shapes = dynamic_shapes - - def run( - self, - artifact: torch.nn.Module, - inputs: Tuple[torch.Tensor], - ) -> None: - self.exported_program = export( - artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True + super().__init__( + quantizer=quantizer or XNNPACKQuantizer(), + quantization_config=( + quantization_config or get_symmetric_quantization_config(is_qat=is_qat) + ), + calibrate=calibrate, + calibration_samples=calibration_samples, + is_qat=is_qat, ) - @property - def artifact(self) -> ExportedProgram: - return self.exported_program - @property - def graph_module(self) -> str: - return self.exported_program.graph_module - - -@register_stage -class ToEdge(Stage): - def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None): - self.edge_compile_conf = ( - edge_compile_config or get_xnnpack_edge_compile_config() - ) - self.edge_dialect_program = None - - def run(self, artifact: ExportedProgram, inputs=None) -> None: - self.edge_dialect_program = to_edge( - artifact, compile_config=self.edge_compile_conf - ) - - @property - def artifact(self) -> EdgeProgramManager: - return self.edge_dialect_program - - @property - def graph_module(self) -> str: - return self.edge_dialect_program.exported_program().graph_module - - -@register_stage -class RunPasses(Stage): +class RunPasses(BaseStages.RunPasses): def __init__( self, pass_list: Optional[List[Type[PassType]]] = None, pass_functions: Optional[List[Callable]] = None, ): - self.pass_list = pass_list - self.pass_functions = pass_functions - self.edge_or_aten_program = None - - def run( - self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None - ) -> None: - if isinstance(artifact, EdgeProgramManager): - self.edge_or_aten_program = artifact - if self.pass_list: - pass_manager = XNNPACKPassManager( - artifact.exported_program(), self.pass_list - ) - self.edge_or_aten_program._edge_programs["forward"] = ( - pass_manager.transform() - ) - if self.pass_functions: - assert isinstance(self.pass_functions, list) - for pass_function in self.pass_functions: - self.edge_or_aten_program._edge_programs["forward"] = pass_function( - self.edge_or_aten_program.exported_program() - ) - else: - transformed_ep = artifact - if self.pass_list: - assert isinstance(self.pass_list, list) - for pass_ in self.pass_list: - transformed_ep = _transform(transformed_ep, pass_()) - - if self.pass_functions: - assert isinstance(self.pass_functions, list) - for pass_function in self.pass_functions: - transformed_ep = pass_function(transformed_ep) - - self.edge_or_aten_program = transformed_ep + super().__init__( + pass_manager_cls=XNNPACKPassManager, + pass_list=pass_list, + pass_functions=pass_functions, + ) - @property - def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]: - return self.edge_or_aten_program - @property - def graph_module(self) -> str: - if isinstance(self.edge_or_aten_program, EdgeProgramManager): - return self.edge_or_aten_program.exported_program().graph_module - else: - return self.edge_or_aten_program.graph_module +class ToEdge(BaseStages.ToEdge): + def __init__(self, edge_compile_config: Optional[EdgeCompileConfig] = None): + super().__init__(edge_compile_config or get_xnnpack_edge_compile_config()) -@register_stage -class ToEdgeTransformAndLower(Stage): +class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower): def __init__( self, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, ): - self.partitioners = partitioners or [XnnpackPartitioner()] - self.edge_compile_conf = ( - edge_compile_config or get_xnnpack_edge_compile_config() - ) - self.edge_dialect_program = None - - def run(self, artifact: ExportedProgram, inputs=None) -> None: - self.edge_dialect_program = to_edge_transform_and_lower( - artifact, - compile_config=self.edge_compile_conf, - partitioner=self.partitioners, + super().__init__( + default_partitioner_cls=XnnpackPartitioner, + partitioners=partitioners, + edge_compile_config=edge_compile_config + or get_xnnpack_edge_compile_config(), ) - @property - def artifact(self) -> EdgeProgramManager: - return self.edge_dialect_program - - @property - def graph_module(self) -> str: - return self.edge_dialect_program.exported_program().graph_module - -@register_stage -class Partition(Stage): +class Partition(BaseStages.Partition): def __init__(self, partitioner: Optional[Partitioner] = None): - self.partitioner = partitioner or XnnpackPartitioner() - self.delegate_module = None - - def run(self, artifact: EdgeProgramManager, inputs=None): - with validation_disabled(): - self.delegate_module = artifact - self.delegate_module = self.delegate_module.to_backend(self.partitioner) - - @property - def artifact(self) -> EdgeProgramManager: - return self.delegate_module - - @property - def graph_module(self) -> str: - return self.delegate_module.exported_program().graph_module - - -@register_stage -class ToExecutorch(Stage): - def __init__( - self, - config: Optional[ExecutorchBackendConfig] = None, - ): - self.config = config or ExecutorchBackendConfig( - extract_delegate_segments=True, - sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), + super().__init__( + partitioner=partitioner or XnnpackPartitioner(), ) - self.executorch_program = None - - def run(self, artifact: EdgeProgramManager, inputs=None): - self.executorch_program = artifact.to_executorch(self.config) - @property - def artifact(self) -> ExecutorchProgramManager: - return self.executorch_program - @property - def graph_module(self) -> str: - return self.executorch_program().graph_module - - def dump_artifact(self, path_to_dump: Optional[str]): - """ - dump_artifact is overridden to dump the serialized program - """ - original_stdout = sys.stdout - - sys.stdout = open(path_to_dump, "a") if path_to_dump else sys.stdout - print(self.stage_banner() + "\n") - pretty_print(self.artifact._emitter_output.program) - print_program( - self.artifact._emitter_output.program, - show_meminfo=True, - mark_dynamic_shape_tensor=True, - ) - sys.stdout = original_stdout - - -@register_stage -class Serialize(Stage): - def __init__(self): - self.buffer = None - - def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None: - self.buffer = artifact.buffer - - @property - def artifact(self) -> bytes: - return self.buffer +class Serialize(BaseStages.Serialize): + pass - @property - def graph_module(self) -> None: - return None - def run_artifact(self, inputs): - inputs_flattened, _ = tree_flatten(inputs) - executorch_module = _load_for_executorch_from_buffer(self.buffer) - executorch_output = copy.deepcopy( - executorch_module.run_method("forward", tuple(inputs_flattened)) - ) - return executorch_output - - def dump_artifact(self, path_to_dump: Optional[str]): - """ - dump_artifact is overridden to dump the serialized bytes into pte file - """ - if not path_to_dump: - raise RuntimeError("path_to_dump file not provided") - else: - with open(path_to_dump, "wb") as f: - f.write(self.artifact) +class ToExecutorch(BaseStages.ToExecutorch): + pass -class Tester: +class Tester(TesterBase): def __init__( self, module: torch.nn.Module, example_inputs: Tuple[torch.Tensor], dynamic_shapes: Optional[Tuple[Any]] = None, ): - module.eval() - - self.original_module = module - self.example_inputs = example_inputs - self.dynamic_shapes = dynamic_shapes - self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys())) - self.pipeline = { - self.stage_name(Quantize): [self.stage_name(Export)], - self.stage_name(Export): [ - self.stage_name(RunPasses), - self.stage_name(ToEdge), - self.stage_name(ToEdgeTransformAndLower), - ], - self.stage_name(ToEdgeTransformAndLower): [ - self.stage_name(RunPasses), - self.stage_name(ToExecutorch), - ], - self.stage_name(ToEdge): [ - self.stage_name(Partition), - self.stage_name(RunPasses), - ], - self.stage_name(RunPasses): [ - self.stage_name(Partition), - self.stage_name(ToEdgeTransformAndLower), - ], - # TODO Make this Stage optional - self.stage_name(Partition): [self.stage_name(ToExecutorch)], - self.stage_name(ToExecutorch): [self.stage_name(Serialize)], - self.stage_name(Serialize): [], - } - assert all( - stage in self.pipeline for stage in self.stages - ), "Invalid Tester internal state!" - - # Current stage name - self.cur: str = "" - - # Reference output from eager mode - self.reference_output = None - - # Quantization scale from eager mode - self.quantization_scale: Optional[float] = None - - # Artifact output from stage - self.stage_output = None - - def generate_random_inputs(self): - # Get shapes of inputs - input_shapes = [] - if self.dynamic_shapes is None: - for tensor_arg in self.example_inputs: - assert isinstance(tensor_arg, torch.Tensor) - input_shapes.append(tensor_arg.shape) - else: - # Random shapes depending on dynamic shape constraint - dim_name_to_size = {} - for arg_idx in range(len(self.example_inputs)): - assert isinstance(self.example_inputs[arg_idx], torch.Tensor) - ex_shape = list(self.example_inputs[arg_idx].shape) - dynamic_dim_spec = self.dynamic_shapes[arg_idx] - for dim_idx, dim_spec in dynamic_dim_spec.items(): - assert dim_idx < len(ex_shape) - if isinstance(dim_spec, torch.export.dynamic_shapes._DerivedDim): - # derived dims are of the form {0: 2 * torch.export.Dim() // 2} - # The root contains the min/max of the export dim and fn contains - # the function to compute the derived dim. - dim_spec = dim_spec.root - fn = dim_spec.fn - elif isinstance(dim_spec, torch.export.dynamic_shapes._Dim): - # Not derived dim so fn is just itself - def fn(x): - return x - - else: - raise RuntimeError( - f"Expected Dynamic Dims to be of type _DerivedDim or _Dim but got {type(dim_spec)}" - ) - dim_name = dim_spec.__name__ - if dim_name not in dim_name_to_size: - upper_bound = min( - dim_spec.max, 1000 - ) # unbounded int max is too large - lower_bound = ( - dim_spec.min if dim_spec.min >= 2 else 1 - ) # 0/1 specialization means dim_spec.min can never be 1 - dim_name_to_size[dim_name] = fn( - random.randint(lower_bound, upper_bound) - ) - ex_shape[dim_idx] = dim_name_to_size[dim_spec.__name__] - input_shapes.append(torch.Size(ex_shape)) - # create random tensor inputs with the shapes given above: - random_inputs = [] - for arg_idx in range(len(self.example_inputs)): - memFormat = get_memory_format( - list(self.example_inputs[arg_idx].dim_order()) - ) - random_inputs.append( - torch.randn(input_shapes[arg_idx]) - .to(dtype=self.example_inputs[arg_idx].dtype) - .to(memory_format=memFormat) - ) - - yield tuple(random_inputs) - - @staticmethod - def stage_name(stage) -> str: - t = stage if isinstance(stage, type) else type(stage) - return t.__qualname__ - - def _pre(self, stage): - name: str = self.stage_name(stage) - assert isinstance(name, str) and name in self.stages and not self.stages[name] - - last_artifact = self.original_module - if self.cur: - assert self.cur in self.pipeline, f"Invalid state: {self.cur}" - allowed_next_stages = self.pipeline[self.cur] - assert name in allowed_next_stages, f"Invalid next stage: {name}" - last_artifact = self.get_artifact() - self.cur = name - return last_artifact - - def _post(self, stage): - name = self.stage_name(stage) - assert name in self.stages - self.stages[name] = stage - - def _run_stage(self, stage_instance, inputs=None): - assert isinstance(stage_instance, Stage) - prev_stage_artifact = self._pre(stage_instance) - stage_instance.run(prev_stage_artifact, inputs=inputs) - self._post(stage_instance) - return self - - # Stages - def quantize(self, quantize_stage: Optional[Quantize] = None): - return self._run_stage(quantize_stage or Quantize(), self.example_inputs) - - def export(self, export_stage: Optional[Export] = None): - return self._run_stage( - export_stage or Export(dynamic_shapes=self.dynamic_shapes), - self.example_inputs, + # Specialize for XNNPACK + stage_classes = ( + executorch.backends.test.harness.Tester.default_stage_classes() + | { + StageType.EXPORT: Export, + StageType.PARTITION: Partition, + StageType.QUANTIZE: Quantize, + StageType.RUN_PASSES: RunPasses, + StageType.TO_EDGE: ToEdge, + StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower, + StageType.SERIALIZE: Serialize, + } ) - def to_edge(self, to_edge_stage: Optional[ToEdge] = None): - if not to_edge_stage: - to_edge_stage = ToEdge() - res = self._run_stage(to_edge_stage) - return res - - def to_edge_transform_and_lower( - self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None - ): - return self._run_stage(to_edge_and_transform_stage or ToEdgeTransformAndLower()) - - def run_passes(self, run_passes_stage: Optional[RunPasses] = None): - return self._run_stage(run_passes_stage or RunPasses()) - - def partition(self, partition_stage: Optional[Partition] = None): - return self._run_stage(partition_stage or Partition()) - - def to_executorch(self, to_executorch_stage: Optional[ToExecutorch] = None): - return self._run_stage(to_executorch_stage or ToExecutorch()) - - def serialize(self, serialize_stage: Optional[Serialize] = None): - return self._run_stage(serialize_stage or Serialize()) - - # Util functions - def dump_artifact(self, path: Optional[str] = None, stage: Optional[str] = None): - stage = stage or self.cur - self.stages[stage].dump_artifact(path) - return self - - def get_artifact(self, stage: Optional[str] = None): - stage = stage or self.cur - return self.stages[stage].artifact - - def check(self, input: List[str]): - for key in input: - FileCheck().check(key).run(self.stages[self.cur].graph_module.code) - return self - - def check_not(self, input: List[str]): - for key in input: - FileCheck().check_not(key).run(self.stages[self.cur].graph_module.code) - return self - - def check_count(self, input: Dict[Any, int]): - # TODO target checks similar to checkGraphModuleNodes() - for key, count in input.items(): - FileCheck().check_count(key, count, exactly=True).run( - self.stages[self.cur].graph_module.code - ) - return self - - def check_node_count(self, input: Dict[Any, int]): - # Count the occurances of each target in the graph. - target_ops = [ - node.target - for node in self.stages[self.cur].graph_module.graph.nodes - if node.op == "call_function" - ] - op_counts = Counter(target_ops) - - for key, count in input.items(): - if count != op_counts[key]: - print(f"Nodes: {op_counts}") - raise AssertionError( - f"Expected {count} {key} nodes but found {op_counts[key]}." - ) - - return self - - def visualize( - self, reuse_server: bool = True, stage: Optional[str] = None, **kwargs - ): - # import here to avoid importing model_explorer when it is not needed which is most of the time. - from executorch.devtools.visualization import visualize - - visualize(self.get_artifact(stage), reuse_server=reuse_server, **kwargs) - return self - - def run_method_and_compare_outputs( - self, - stage: Optional[str] = None, - inputs: Optional[Tuple[torch.Tensor]] = None, - num_runs=1, - atol=1e-03, - rtol=1e-03, - qtol=0, - ): - number_of_runs = 1 if inputs is not None else num_runs - reference_stage = self.stages[self.stage_name(Export)] - - stage = stage or self.cur - - print(f"Comparing Stage {stage} with Stage {reference_stage}") - for run_iteration in range(number_of_runs): - inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) - input_shapes = [generated_input.shape for generated_input in inputs_to_run] - print(f"Run {run_iteration} with input shapes: {input_shapes}") - - # Reference output (and quantization scale) - ( - reference_output, - quantization_scale, - ) = self._calculate_reference_output( - reference_stage.artifact, inputs_to_run - ) - - # Output from running artifact at stage - stage_output = self.stages[stage].run_artifact(inputs_to_run) - self._compare_outputs( - reference_output, stage_output, quantization_scale, atol, rtol, qtol - ) - - return self - - @staticmethod - def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): - """ - Helper testing function that asserts that the model output and the reference output - are equal with some tolerance. Due to numerical differences between eager mode and - the XNNPACK's backend, we relax the detal such that absolute tolerance is 1e-3. and - relative tolerance is 1e-3. In the event that the computation was quantized, we - further relax the tolerance to one quantized step (equal to the quantization scale). - This allows the quantized value to differ by 1 between the reference and model output. - """ - - assert len(model_output) == len(ref_output) - - for i in range(len(model_output)): - model = model_output[i] - ref = ref_output[i] - assert ( - ref.shape == model.shape - ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}" - if model.dtype == torch.bool: - assert torch.equal(model, ref), ( - f"Output {i} (bool tensor) does not match reference output.\n" - f"\tShape: {model.shape}\n" - f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" - ) - else: - assert torch.allclose( - model, - ref, - atol=atol, - rtol=rtol, - ), ( - f"Output {i} does not match reference output.\n" - f"\tGiven atol: {atol}, rtol: {rtol}.\n" - f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" - f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" - f"\t-- Model vs. Reference --\n" - f"\t Numel: {model.numel()}, {ref.numel()}\n" - f"\tMedian: {model.median()}, {ref.median()}\n" - f"\t Mean: {model.mean()}, {ref.mean()}\n" - f"\t Max: {model.max()}, {ref.max()}\n" - f"\t Min: {model.min()}, {ref.min()}\n" - ) - - @staticmethod - def _compare_outputs( - reference_output, - stage_output, - quantization_scale=None, - atol=1e-03, - rtol=1e-03, - qtol=0, - ): - """ - Compares the original of the original nn module with the output of the generated artifact. - This requres calling run_method before calling compare_outputs. As that runs the generated - artifact on the sample inputs and sets the stage output to be compared against the reference. - """ - # Wrap both outputs as tuple, since executor output is always a tuple even if single tensor - if isinstance(reference_output, torch.Tensor): - reference_output = (reference_output,) - if isinstance(stage_output, torch.Tensor): - stage_output = (stage_output,) - - # If a qtol is provided and we found an dequantization node prior to the output, relax the - # atol by qtol quant units. - if quantization_scale is not None: - atol += quantization_scale * qtol - - Tester._assert_outputs_equal( - stage_output, - reference_output, - atol=atol, - rtol=rtol, + super().__init__( + module=module, + stage_classes=stage_classes, + example_inputs=example_inputs, + dynamic_shapes=dynamic_shapes, ) - - @staticmethod - def _calculate_reference_output( - program: ExportedProgram, inputs - ) -> Tuple[torch.Tensor, Optional[float]]: - """ - Execute the reference program and return the output. If the output comes from a dequantize node, - return the quantization scale as well. - """ - - # Locate the output node. - output_node = None - for node in program.graph.nodes: - if node.op == "output": - output_node = node - break - assert output_node is not None - - # Look for a dequantization node in the output node args. Returned values are found in the first - # argument of the output node. - dequant_node = None - for arg_node in output_node.args[0]: - if ( - arg_node.op == "call_function" - and arg_node.target - == torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): - dequant_node = arg_node - break - - scale = None - if dequant_node is not None: - original_target = dequant_node.target - - # Replace the dequant node with shim to intercept the quantization parameters. - # It will be invoked when we evaluate the program to find the reference outputs. - def dequant_shim(*args): - nonlocal scale - scale = args[1] - result = original_target(*args) - return result - - dequant_node.target = dequant_shim - - output = program.module()(*inputs) - return output, scale diff --git a/backends/xnnpack/utils/utils.py b/backends/xnnpack/utils/utils.py index b23fd444117..a8f3178f98f 100644 --- a/backends/xnnpack/utils/utils.py +++ b/backends/xnnpack/utils/utils.py @@ -25,6 +25,7 @@ is_lifted_tensor_constant, is_param, ) +from torchao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node ### XNNPACK Capture ### @@ -160,6 +161,36 @@ def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]: return source_fn[1] +def get_groups_from_conv(conv_node: torch.fx.Node) -> int: + if _is_conv_node(conv_node): + in_node = cast(torch.fx.Node, conv_node.args[0]) + weight_node = cast(torch.fx.Node, conv_node.args[1]) + # groups isn't given to us in the training graph so we deduce it from the weight shape + # and the input shape + + # input shape is (N, C_in, H_in, W_in) + in_channels = in_node.meta["val"].shape[1] + + # weight shape is (C_out, C_in/groups, kernel_size[0], kernel_size[1]) + in_groups = weight_node.meta["val"].shape[1] + + return in_channels // in_groups + elif _is_conv_transpose_node(conv_node): + weight_node = cast(torch.fx.Node, conv_node.args[1]) + # groups isn't given to us in the training graph so we deduce it from the weight shape + # and the output shape + + # weight shape is (C_in, C_out/groups, kernel_size[0], kernel_size[1]) + out_groups = weight_node.meta["val"].shape[1] + + # output shape is (N, C_out, H_out, W_out) + out_channels = conv_node.meta["val"].shape[1] + + return out_channels // out_groups + + raise RuntimeError(f"expected {conv_node} to be a conv or conv_transpose node") + + def is_depthwise_conv( kernel_shape: Tuple[int, ...], groups: int = 1, is_transpose: bool = False ) -> bool: diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index d8d6c20fb20..0712bdf1f9a 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -19,6 +19,7 @@ python_library( "//executorch/devtools/etrecord:etrecord", "//executorch/exir:lib", "//executorch/devtools/inspector:intermediate_output_capturer", + "//executorch/devtools/inspector/numerical_comparator:lib", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 199a740737a..dfff3d0818e 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -55,6 +55,7 @@ inflate_runtime_output, is_debug_output, is_inference_output_equal, + map_runtime_aot_intermediate_outputs, ProgramOutput, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, @@ -63,6 +64,10 @@ from executorch.devtools.inspector._intermediate_output_capturer import ( IntermediateOutputCapturer, ) +from executorch.devtools.inspector.numerical_comparator import ( + L1Comparator, + MSEComparator, +) from executorch.exir import ExportedProgram @@ -1337,3 +1342,50 @@ def get_exported_program( if graph is None else self._etrecord.graph_map.get(graph) ) + + def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame: + """ + Compares logged intermediate outputs from the exported graph (in ETRecord) + with runtime outputs (in ETDump) using a user-specific numerical comparator. + + Args: + distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + + Returns: + pd.DataFrame: A DataFrame listing corresponding operator outputs from + both stages and their computed numerical gaps. + """ + if self._aot_intermediate_outputs is None: + raise ValueError( + "The aot intermediate outputs is required but not populated." + ) + mapping = map_runtime_aot_intermediate_outputs( + self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs() + ) + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator() + elif metric == "L1": + comparator = L1Comparator() + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + rows = [] + for (aot_debug_handle, aot_intermediate_output), ( + runtime_debug_handle, + runtime_intermediate_output, + ) in mapping.items(): + if aot_intermediate_output is None or runtime_intermediate_output is None: + continue + rows.append( + { + "aot_debug_handle": aot_debug_handle, + "aot_intermediate_output": aot_intermediate_output, + "runtime_debug_handle": runtime_debug_handle, + "runtime_intermediate_output": runtime_intermediate_output, + "gap": comparator.compare( + aot_intermediate_output, runtime_intermediate_output + ), + } + ) + return pd.DataFrame(rows) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 67fcc807752..21d627d4eba 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -8,6 +8,7 @@ import math import sys +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union @@ -676,17 +677,56 @@ def map_runtime_aot_intermediate_outputs( # Map only if both AOT and runtime data are present. if len(aot_list) != 0 and len(runtime_list) != 0: # Combine aot debug handles into a single key - aot_combined_debug_handle, aot_output = ( + aot_combined_debug_handle, aot_intermediate_output = ( _combine_overlapped_intermediate_outputs(aot_list) ) # Combine runtime debug handles into a single key - runtime_combined_debug_handle, runtime_output = ( + runtime_combined_debug_handle, runtime_intermediate_output = ( _combine_overlapped_intermediate_outputs(runtime_list) ) + # List can't be used as a key, so convert to tuple + if isinstance(aot_intermediate_output, list): + aot_intermediate_output = tuple(aot_intermediate_output) + # runtime follow the same format as aot, so it's safe to convert to tuple + if isinstance(runtime_intermediate_output, list): + runtime_intermediate_output = tuple(runtime_intermediate_output) # Create a mapping between runtime and aot - aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = ( + aot_runtime_mapping[ + (aot_combined_debug_handle, aot_intermediate_output) + ] = ( runtime_combined_debug_handle, - runtime_output, + runtime_intermediate_output, ) return aot_runtime_mapping + + +def convert_to_float_tensor(input_data: Any) -> torch.Tensor: + """ + Convert input_data into a torch.Tensor on CPU with dtype torch.float64. + This function handles the following types of input: + - Scalar (int or float): Converts to a tensor with a single element. + - Tensor: Converts to a float64 tensor on CPU. + - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU. + The resulting tensor is detached, moved to CPU, and cast to torch.float64. + Parameters: + input_data (Any): The input data to be converted to a tensor. It can be a scalar, + a tensor, or a list of tensors. + Returns: + torch.Tensor: A tensor on CPU with dtype torch.float64. + Raises: + ValueError: If the input_data cannot be converted to a tensor. + """ + try: + # Check if the input is a Sequence of tensors + if isinstance(input_data, Sequence): + input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data]) + # Try to convert the input to a tensor + else: + input_tensor = torch.as_tensor(input_data, dtype=torch.float64) + except Exception as e: + raise ValueError( + f"Cannot convert value of type {type(input_data)} to a tensor: {e}" + ) + input_tensor = input_tensor.detach().cpu().double() + return input_tensor diff --git a/devtools/inspector/numerical_comparator/TARGETS b/devtools/inspector/numerical_comparator/TARGETS new file mode 100644 index 00000000000..1c0fc8abb85 --- /dev/null +++ b/devtools/inspector/numerical_comparator/TARGETS @@ -0,0 +1,37 @@ +load("@fbcode_macros//build_defs:python_library.bzl", "python_library") + +oncall("executorch") + + +python_library( + name = "numerical_comparator_base", + srcs = ["numerical_comparator_base.py"], + deps = [], +) + +python_library( + name = "l1_numerical_comparator", + srcs = ["l1_numerical_comparator.py"], + deps = [ + "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", + "//executorch/devtools/inspector:inspector_utils", + ], +) + +python_library( + name = "mse_numerical_comparator", + srcs = ["mse_numerical_comparator.py"], + deps = [ + "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", + "//executorch/devtools/inspector:inspector_utils", + ], +) + +python_library( + name = "lib", + srcs = ["__init__.py"], + deps = [ + ":l1_numerical_comparator", + ":mse_numerical_comparator", + ], +) diff --git a/devtools/inspector/numerical_comparator/__init__.py b/devtools/inspector/numerical_comparator/__init__.py new file mode 100644 index 00000000000..6540b8c25e1 --- /dev/null +++ b/devtools/inspector/numerical_comparator/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from executorch.devtools.inspector.numerical_comparator.l1_numerical_comparator import ( + L1Comparator, +) + +from executorch.devtools.inspector.numerical_comparator.mse_numerical_comparator import ( + MSEComparator, +) + + +__all__ = ["L1Comparator", "MSEComparator"] diff --git a/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py b/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py new file mode 100644 index 00000000000..b6dac7e1970 --- /dev/null +++ b/devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +from abc import ABC, abstractmethod +from typing import Any + + +class InspectorNumericalComparatorBase(ABC): + @abstractmethod + def compare(self, a: Any, b: Any) -> float: + """Compare two intermediate output and return a result. + + This method should be overridden by subclasses to provide custom comparison logic. + + Args: + a: The first intermediate output to compare. + b: The second intermediate output to compare. + + Returns: + A numerical result indicating the comparison outcome. + """ + pass diff --git a/devtools/inspector/numerical_comparator/l1_numerical_comparator.py b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py new file mode 100644 index 00000000000..277cfb63cdc --- /dev/null +++ b/devtools/inspector/numerical_comparator/l1_numerical_comparator.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor +from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + NumericalComparatorBase, +) + + +class L1Comparator(NumericalComparatorBase): + def compare(self, a: Any, b: Any) -> float: + """Sum up all these element-wise absolute differences between two tensors.""" + + t_a = convert_to_float_tensor(a) + t_b = convert_to_float_tensor(b) + if torch.isnan(t_a).any() or torch.isnan(t_b).any(): + t_a = torch.nan_to_num(t_a) + t_b = torch.nan_to_num(t_b) + + try: + res = torch.abs(t_a - t_b).sum().item() + except Exception as e: + raise ValueError(f"Error computing L1 difference between tensors: {str(e)}") + return res diff --git a/devtools/inspector/numerical_comparator/mse_numerical_comparator.py b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py new file mode 100644 index 00000000000..cb27a44fa22 --- /dev/null +++ b/devtools/inspector/numerical_comparator/mse_numerical_comparator.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor +from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import ( + NumericalComparatorBase, +) + + +class MSEComparator(NumericalComparatorBase): + def compare(self, a: Any, b: Any) -> float: + """Compare mean squared difference between two outputs.""" + + t_a = convert_to_float_tensor(a) + t_b = convert_to_float_tensor(b) + if torch.isnan(t_a).any() or torch.isnan(t_b).any(): + t_a = torch.nan_to_num(t_a) + t_b = torch.nan_to_num(t_b) + + try: + res = float(torch.mean(torch.square(t_a - t_b))) + except Exception as e: + raise ValueError( + f"Error computing MSE difference between tensors: {str(e)}" + ) + return res diff --git a/devtools/inspector/numerical_comparator/numerical_comparator_base.py b/devtools/inspector/numerical_comparator/numerical_comparator_base.py new file mode 100644 index 00000000000..db498980e1f --- /dev/null +++ b/devtools/inspector/numerical_comparator/numerical_comparator_base.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from abc import ABC, abstractmethod +from typing import Any + + +class NumericalComparatorBase(ABC): + @abstractmethod + def compare(self, a: Any, b: Any) -> float: + """Compare two intermediate output and return a result. + + This method should be overridden by subclasses to provide custom comparison logic. + + Args: + a: The first intermediate output to compare. + b: The second intermediate output to compare. + + Returns: + A numerical result indicating the comparison outcome. + """ + pass diff --git a/devtools/inspector/tests/TARGETS b/devtools/inspector/tests/TARGETS index b5fbeda215b..e4bc78775d5 100644 --- a/devtools/inspector/tests/TARGETS +++ b/devtools/inspector/tests/TARGETS @@ -54,6 +54,22 @@ python_unittest( ], ) +python_unittest( + name = "l1_comparator_test", + srcs = ["l1_comparator_test.py"], + deps = [ + "//executorch/devtools/inspector/numerical_comparator:lib", + ], +) + +python_unittest( + name = "mse_comparator_test", + srcs = ["mse_comparator_test.py"], + deps = [ + "//executorch/devtools/inspector/numerical_comparator:lib", + ], +) + python_library( name = "inspector_test_utils", srcs = [ diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index b96a694b581..1460dbd46a2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,6 +17,8 @@ from unittest.mock import patch +import pandas as pd + import torch import torch.fx @@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self): self.assertIn((key,), runtime_outputs) self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) + def test_calculate_numeric_gap(self): + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + runtime_intermediate_outputs = { + (0,): torch.tensor([2.0, 1.0, 4.0]), + (1,): torch.tensor([3.0, 6.0, 5.0]), + } + + inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs + inspector_instance._get_runtime_intermediate_outputs = ( + lambda: runtime_intermediate_outputs + ) + + df = inspector_instance.calculate_numeric_gap(distance="L1") + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_debug_handle", + "aot_intermediate_output", + "runtime_debug_handle", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + founded_aot_debug_handle = set(df["aot_debug_handle"]) + self.assertEqual( + founded_aot_debug_handle, set(aot_intermediate_outputs.keys()) + ) + for _, row in df.iterrows(): + aot_debuh_handle = row["aot_debug_handle"] + # aot_intermediate_output should equal aot_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["aot_intermediate_output"], + aot_intermediate_outputs[aot_debuh_handle], + ) + ) + # runtime_debug_hanlde equals aot_debug_handle at this case + self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle) + # runtime_intermediate_output should equal runtime_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["runtime_intermediate_output"], + runtime_intermediate_outputs[aot_debuh_handle], + ) + ) + # gap should equal 3.0 + self.assertEqual(row["gap"], 3.0) + def _gen_random_float_list(self) -> List[float]: return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)] diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 38ed2c29ea2..8148d2c36f0 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -29,6 +29,7 @@ calculate_mse, calculate_snr, calculate_time_scale_factor, + convert_to_float_tensor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, find_populated_event, @@ -317,6 +318,52 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self): expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)} self.assertEqual(actual, expected) + def test_convert_input_to_tensor_convertible_inputs(self): + # Scalar -> tensor + actual_output1 = convert_to_float_tensor(5) + self.assertIsInstance(actual_output1, torch.Tensor) + self.assertEqual(actual_output1.dtype, torch.float64) + self.assertEqual(tuple(actual_output1.shape), ()) + self.assertTrue( + torch.allclose(actual_output1, torch.tensor([5.0], dtype=torch.float64)) + ) + self.assertEqual(actual_output1.device.type, "cpu") + + # Tensor of ints -> float32 CPU + t_int = torch.tensor([4, 5, 6], dtype=torch.int32) + actual_output2 = convert_to_float_tensor(t_int) + self.assertIsInstance(actual_output2, torch.Tensor) + self.assertEqual(actual_output2.dtype, torch.float64) + self.assertTrue( + torch.allclose( + actual_output2, torch.tensor([4.0, 5.0, 6.0], dtype=torch.float64) + ) + ) + self.assertEqual(actual_output2.device.type, "cpu") + + # List of tensors -> stacked tensor float32 CPU + t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])] + actual_output3 = convert_to_float_tensor(t_list) + self.assertIsInstance(actual_output3, torch.Tensor) + self.assertEqual(actual_output3.dtype, torch.float64) + self.assertEqual(tuple(actual_output3.shape), (3, 2)) + self.assertTrue( + torch.allclose( + actual_output3, + torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64), + ) + ) + self.assertEqual(actual_output3.device.type, "cpu") + + def test_convert_input_to_tensor_non_convertible_raises(self): + class X: + pass + + with self.assertRaises(ValueError) as cm: + convert_to_float_tensor(X()) + msg = str(cm.exception) + self.assertIn("Cannot convert value of type", msg) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/devtools/inspector/tests/l1_comparator_test.py b/devtools/inspector/tests/l1_comparator_test.py new file mode 100644 index 00000000000..9a14a410311 --- /dev/null +++ b/devtools/inspector/tests/l1_comparator_test.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from executorch.devtools.inspector.numerical_comparator import L1Comparator + + +class TestL1Comparator(unittest.TestCase): + l1_comparator = L1Comparator() + + def test_identical_tensors(self): + a = torch.tensor([[1, 2], [3, 4]]) + b = torch.tensor([[1, 2], [3, 4]]) + expected = 0.0 + result = self.l1_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_scalar(self): + a = 1 + b = 2 + expected = 1.0 + result = self.l1_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_with_nans_replaced_with_zero(self): + a = torch.tensor([3, 2, -1, float("nan")]) + b = torch.tensor([float("nan"), 0, -3, 1]) + expected = 8.0 + result = self.l1_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_shape_mismatch_raises_exception(self): + a = torch.tensor([0, 2, -1]) + b = torch.tensor([1, 0, -3, 4]) + with self.assertRaises(ValueError): + self.l1_comparator.compare(a, b) + + def test_2D_tensors(self): + a = torch.tensor([[4, 9], [6, 4]]) + b = torch.tensor([[1, 2], [3, 5]]) + expected = 14.0 + result = self.l1_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_list_of_tensors(self): + a = [torch.tensor([2, 4]), torch.tensor([5, 2])] + b = [torch.tensor([1, 2]), torch.tensor([3, 5])] + expected = 8.0 + result = self.l1_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) diff --git a/devtools/inspector/tests/mse_comparator_test.py b/devtools/inspector/tests/mse_comparator_test.py new file mode 100644 index 00000000000..ee6b90dea1c --- /dev/null +++ b/devtools/inspector/tests/mse_comparator_test.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch + +from executorch.devtools.inspector.numerical_comparator import MSEComparator + + +class TestMSEComparator(unittest.TestCase): + mse_comparator = MSEComparator() + + def test_identical_tensors(self): + a = torch.tensor([[10, 4], [3, 4]]) + b = torch.tensor([[10, 4], [3, 4]]) + expected = 0.0 + result = self.mse_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_scalar(self): + a = 10 + b = 2 + expected = 64.0 + result = self.mse_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_with_nans_replaced_with_zero(self): + a = torch.tensor([3, 1, -3, float("nan")]) + b = torch.tensor([float("nan"), 0, -3, 2]) + expected = (9.0 + 1.0 + 0.0 + 4.0) / 4.0 + result = self.mse_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_shape_mismatch_raises_exception(self): + a = torch.tensor([0, 2, -1]) + b = torch.tensor([1, 1, -3, 4]) + with self.assertRaises(ValueError): + self.mse_comparator.compare(a, b) + + def test_2D_tensors(self): + a = torch.tensor([[4, 9], [6, 4]]) + b = torch.tensor([[1, 2], [3, 10]]) + expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0 + result = self.mse_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) + + def test_list_of_tensors(self): + a = [torch.tensor([2, 4]), torch.tensor([15, 2])] + b = [torch.tensor([1, 2]), torch.tensor([9, 5])] + expected = (1.0 + 4.0 + 36.0 + 9.0) / 4.0 + result = self.mse_comparator.compare(a, b) + self.assertAlmostEqual(result, expected) diff --git a/docs/source/backends-mediatek.md b/docs/source/backends-mediatek.md index 7200f24bf98..a562cea13bd 100644 --- a/docs/source/backends-mediatek.md +++ b/docs/source/backends-mediatek.md @@ -1,95 +1,79 @@ # MediaTek Backend -MediaTek backend empowers ExecuTorch to speed up PyTorch models on edge devices that equips with MediaTek Neuron Processing Unit (NPU). This document offers a step-by-step guide to set up the build environment for the MediaTek ExecuTorch libraries. - -::::{grid} 2 -:::{grid-item-card} What you will learn in this tutorial: -:class-card: card-prerequisites -* How to export and lower a PyTorch model ahead of time with ExecuTorch for MediaTek devices. -* How to build MediaTek backend and examples. -* How to deploy the exported models on device with ExecuTorch runtime. -::: -:::{grid-item-card} Tutorials we recommend you complete before this: -:class-card: card-prerequisites -* [Introduction to ExecuTorch](intro-how-it-works.md) -* [Getting Started](getting-started.md) -* [Building ExecuTorch with CMake](using-executorch-building-from-source.md) -::: -:::: - - -## Prerequisites (Hardware and Software) - -### Host OS -- Linux operating system - -### Supported Chips: -- MediaTek Dimensity 9300 (D9300) -- MediaTek Dimensity 9400 (D9400) +The MediaTek backend enables acceleration of PyTorch models on edge devices with MediaTek Neuron Processing Units (NPUs). This backend provides tools for exporting, building, and deploying models to leverage MediaTek hardware. -### Software: +## Features -- [NeuroPilot Express SDK](https://neuropilot.mediatek.com/resources/public/npexpress/en/docs/npexpress) is a lightweight SDK for deploying AI applications on MediaTek SOC devices. +- Acceleration of PyTorch models on MediaTek NPUs +- Tools for model export and lowering +- Example scripts for model deployment and execution -## Setting up your developer environment +## Target Requirements -Follow the steps below to setup your build environment: +- **Hardware:** MediaTek Dimensity 9300 (D9300), Dimensity 9400 (D9400) +- **Host OS:** Linux +- **SDK:** [NeuroPilot Express SDK](https://neuropilot.mediatek.com/resources/public/npexpress/en/docs/npexpress) -1. **Setup ExecuTorch Environment**: Refer to the [Getting Started](getting-started.md) guide for detailed instructions on setting up the ExecuTorch environment. +## Development Requirements -2. **Setup MediaTek Backend Environment** - ```bash - pip3 install -r requirements.txt - ``` -- Install the two .whl downloaded from NeuroPilot Portal - ```bash - pip3 install mtk_neuron-8.2.19-py3-none-linux_x86_64.whl - pip3 install mtk_converter-8.13.0+public-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - ``` -- Set evironment variables for building backend - ```bash - export NEURON_BUFFER_ALLOCATOR_LIB= - ``` -Additionally, make sure to copy `NeuronAdapter.h` to the following directory: `backends/mediatek/runtime/include/api/`. +- Linux operating system +- Python dependencies: + ```bash + pip3 install -r requirements.txt + ``` +- NeuroPilot SDK Python wheels (download from [NeuroPilot Express SDK](https://neuropilot.mediatek.com/resources/public/npexpress/en/docs/npexpress)): + ```bash + pip3 install mtk_neuron-8.2.19-py3-none-linux_x86_64.whl + pip3 install mtk_converter-8.13.0+public-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + ``` -## Build +## Using the MediaTek Backend -### Ahead of time: +### Exporting and Lowering a Model -**Exporting a PyTorch Model for MediaTek Backend**: -1. Lower and export the `.pte` file for on-device execution. The export script samples are povided under `example/mediatek/`. For example, the following commnad exports the `.pte` using the scripts provided. +To export and lower a model for the MediaTek backend, use the provided shell script: ```bash cd executorch - ./examples/mediatek/shell_scripts/export_oss.sh mobilenetv3 ``` +The exported `.pte` file is saved in a directory named after the model. -2. Find the `.pte` files under the directory named as same as the model. +### Partitioner API -### Runtime: +A list of CompileSpec is suppported by MediaTek backend: +- `platform-config`: Specifies the targeted MediaTek platform name to compile for. -**Build MediaTek Backend for ExecuTorch Runtime** -1. Navigate to `backends/mediatek/scripts/` directory. +## Runtime Integration -2. **Build MediaTek Backend**: Once the prerequisites are in place, run the `mtk_build.sh` script to start the build process: - ```bash - ./mtk_build.sh - ``` +This section presents an example of exporting and deploying a model. Please refer to `executorch/examples/mediatek/` for export and execution examples of various of models. -3. MediaTek backend will be built under `cmake-android-out/backends/` as `libneuron_backend.so`. +### Building Example Runners -**Build a runner to execute the model on the device**: -1. Build the runners and the backend by exedcuting the script: +Build example runners: ```bash ./mtk_build_examples.sh ``` +Runners are located in `cmake-android-out/examples/mediatek/`. -2. The runners will be built under `cmake-android-out/examples/` +### Deploying to Device -## Deploying and running on a device +1. Push `libneuron_backend.so`, `libneuronusdk_adapter.mtk.so` and `libneuron_buffer_allocator.so` to the device. +2. Set the library path before running ExecuTorch: + ```bash + export LD_LIBRARY_PATH=:::$LD_LIBRARY_PATH + ``` -1. **Push MediaTek universal SDK and MediaTek backend to the device**: push `libneuronusdk_adapter.mtk.so` and `libneuron_backend.so` to the phone and export it to the `$LD_LIBRARY_PATH` environment variable before executing ExecuTorch with MediaTek backend. +### Building the Backend from Source +1. Copy `NeuronAdapter.h` to `backends/mediatek/runtime/include/api/` +2. Set NDK Path: Ensure that the `$ANDROID_NDK` environment variable is set to the path where the NDK is located. ```bash - export LD_LIBRARY_PATH=::$LD_LIBRARY_PATH + export ANDROID_NDK= ``` + +3. Build the backend library `libneuron_backend.so`: + ```bash + cd backends/mediatek/scripts/ + ./mtk_build.sh + ``` +The output is `libneuron_backend.so` in `cmake-android-out/backends/mediatek/`. diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index 46ab379f186..b6bd1eab7c6 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -14,7 +14,7 @@ The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs - ARM64 on Android, iOS, macOS, Linux, and Windows. - ARMv7 (with NEON) on Android. - ARMv6 (with VFPv2) on Linux. -- x86 and x86-64 (up to AVX512) on Windows, Linux, macOS, Android, and iOS simulator. +- x86 and x86-64 (up to AVX512) on Windows, Linux, Android. ## Development Requirements diff --git a/docs/source/getting-started.md b/docs/source/getting-started.md index b7a97190b49..be15e7d6ea2 100644 --- a/docs/source/getting-started.md +++ b/docs/source/getting-started.md @@ -10,8 +10,9 @@ The following are required to install the ExecuTorch host libraries, needed to e - Python 3.10 - 3.12 - g++ version 7 or higher, clang++ version 5 or higher, or another C++17-compatible toolchain. -- Linux or MacOS operating system (Arm or x86). - - Windows is supported via WSL. +- Linux (x86_64 or ARM64) or macOS (ARM64). + - Intel-based macOS systems require building PyTorch from source (see [Building From Source](using-executorch-building-from-source.md) for instructions). + - Windows is supported via WSL. ## Installation To use ExecuTorch, you will need to install both the Python package and the appropriate platform-specific runtime libraries. Pip is the recommended way to install the ExecuTorch python package. diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index bbda61aadd8..04cab007f65 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -154,6 +154,7 @@ mkdir cmake-out cmake \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/docs/source/using-executorch-building-from-source.md b/docs/source/using-executorch-building-from-source.md index f3aac4024af..7680e7bb743 100644 --- a/docs/source/using-executorch-building-from-source.md +++ b/docs/source/using-executorch-building-from-source.md @@ -16,7 +16,7 @@ Linux (x86_64) - Ubuntu 20.04.6 LTS+ - RHEL 8+ -macOS (x86_64/M1/M2) +macOS (x86_64/ARM64) - Big Sur (11.0)+ Windows (x86_64) @@ -56,13 +56,27 @@ Or alternatively, [install conda on your machine](https://conda.io/projects/cond conda create -yn executorch python=3.10.0 && conda activate executorch ``` -## Install ExecuTorch pip package from Source +## Install ExecuTorch pip package from source ```bash # Install ExecuTorch pip package and its dependencies, as well as # development tools like CMake. # If developing on a Mac, make sure to install the Xcode Command Line Tools first. + # Intel-based macOS systems require building PyTorch from source (see below) ./install_executorch.sh ``` + + See the [PyTorch instructions](https://github.com/pytorch/pytorch#installation) on how to build PyTorch from source. + + Use the [`--use-pt-pinned-commit` flag](../../install_executorch.py) to install ExecuTorch with an existing PyTorch build: + + ```bash + ./install_executorch.sh --use-pt-pinned-commit + ``` + + For Intel-based macOS systems, use the [`--use-pt-pinned-commit --minimal` flags](../../install_executorch.py): + ```bash + ./install_executorch.sh --use-pt-pinned-commit --minimal + ``` Not all backends are built into the pip wheel by default. You can link these missing/experimental backends by turning on the corresponding cmake flag. For example, to include the MPS backend: diff --git a/docs/source/using-executorch-ios.md b/docs/source/using-executorch-ios.md index 508669112f1..e3668a29e33 100644 --- a/docs/source/using-executorch-ios.md +++ b/docs/source/using-executorch-ios.md @@ -4,7 +4,7 @@ ExecuTorch supports both iOS and macOS via Objective-C, Swift, and C++. ExecuTor ## Integration -The ExecuTorch Runtime for iOS and macOS is distributed as a collection of prebuilt [.xcframework](https://developer.apple.com/documentation/xcode/creating-a-multi-platform-binary-framework-bundle) binary targets. These targets are compatible with both iOS and macOS devices and simulators and are available in both release and debug modes: +The ExecuTorch Runtime for iOS and macOS (ARM64) is distributed as a collection of prebuilt [.xcframework](https://developer.apple.com/documentation/xcode/creating-a-multi-platform-binary-framework-bundle) binary targets. These targets are compatible with both iOS and macOS devices and simulators and are available in both release and debug modes: * `executorch` - Main Runtime components * `backend_coreml` - Core ML backend diff --git a/examples/arm/CMakeLists.txt b/examples/arm/CMakeLists.txt index 4bae20d2c1f..58466faeca5 100644 --- a/examples/arm/CMakeLists.txt +++ b/examples/arm/CMakeLists.txt @@ -29,6 +29,8 @@ endif() set(_common_compile_options -Wno-deprecated-declarations -fPIC) +add_compile_options("-Wall" "-Werror") + # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 5449ced09b9..148f9c1d477 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -385,6 +385,7 @@ def get_compile_spec( intermediates: Optional[str] = None, system_config: Optional[str] = None, memory_mode: Optional[str] = None, + quantize: bool = False, ) -> list[CompileSpec]: spec_builder = None if target.startswith("TOSA"): @@ -401,7 +402,11 @@ def get_compile_spec( extra_flags="--verbose-operators --verbose-cycle-estimate", ) elif "vgf" in target: - spec_builder = ArmCompileSpecBuilder().vgf_compile_spec() + if quantize: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT") + else: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP") + spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec) if intermediates is not None: spec_builder.dump_intermediate_artifacts_to(intermediates) @@ -700,6 +705,7 @@ def to_edge_TOSA_delegate( args.intermediates, args.system_config, args.memory_mode, + args.quantize, ) model_int8 = None @@ -739,6 +745,7 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_ args.intermediates, args.system_config, args.memory_mode, + args.quantize, ) model, exported_program = quantize_model( args, model, example_inputs, compile_spec diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index debc955dcc0..5944a1f081c 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -201,7 +201,13 @@ void et_pal_emit_log_message( const char* message, ET_UNUSED size_t length) { fprintf( - stderr, "%c [executorch:%s:%zu] %s\n", level, filename, line, message); + stderr, + "%c [executorch:%s:%zu %s()] %s\n", + level, + filename, + line, + function, + message); } /** @@ -643,29 +649,41 @@ int main(int argc, const char* argv[]) { ET_CHECK(status == Error::Ok); for (int i = 0; i < inputs.size(); ++i) { - Tensor t = inputs[i].toTensor(); - // The output might be collected and parsed so printf() is used instead - // of ET_LOG() here - for (int j = 0; j < inputs[i].toTensor().numel(); ++j) { - if (t.scalar_type() == ScalarType::Int) { - printf( - "Input[%d][%d]: (int) %d\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); - } else if (t.scalar_type() == ScalarType::Float) { - printf( - "Input[%d][%d]: (float) %f\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); - } else if (t.scalar_type() == ScalarType::Char) { - printf( - "Input[%d][%d]: (char) %d\n", - i, - j, - inputs[i].toTensor().const_data_ptr()[j]); + if (inputs[i].isTensor()) { + Tensor t = inputs[i].toTensor(); + // The output might be collected and parsed so printf() is used instead + // of ET_LOG() here + for (int j = 0; j < inputs[i].toTensor().numel(); ++j) { + if (t.scalar_type() == ScalarType::Int) { + printf( + "Input[%d][%d]: (int) %d\n", + i, + j, + inputs[i].toTensor().const_data_ptr()[j]); + } else if (t.scalar_type() == ScalarType::Float) { + printf( + "Input[%d][%d]: (float) %f\n", + i, + j, + inputs[i].toTensor().const_data_ptr()[j]); + } else if (t.scalar_type() == ScalarType::Char) { + printf( + "Input[%d][%d]: (char) %d\n", + i, + j, + inputs[i].toTensor().const_data_ptr()[j]); + } else if (t.scalar_type() == ScalarType::Bool) { + printf( + "Input[%d][%d]: (bool) %s (0x%x)\n", + i, + j, + inputs[i].toTensor().const_data_ptr()[j] ? "true" + : "false", + inputs[i].toTensor().const_data_ptr()[j]); + } } + } else { + printf("Input[%d]: Not Tensor\n", i); } } } @@ -760,6 +778,14 @@ int main(int argc, const char* argv[]) { i, j, outputs[i].toTensor().const_data_ptr()[j]); + } else if (t.scalar_type() == ScalarType::Bool) { + printf( + "Output[%d][%d]: (bool) %s (0x%x)\n", + i, + j, + outputs[i].toTensor().const_data_ptr()[j] ? "true " + : "false", + outputs[i].toTensor().const_data_ptr()[j]); } } #endif diff --git a/examples/demo-apps/android/LlamaDemo/README.md b/examples/demo-apps/android/LlamaDemo/README.md index 4b8cafd2d4e..8fed04d7ff5 100644 --- a/examples/demo-apps/android/LlamaDemo/README.md +++ b/examples/demo-apps/android/LlamaDemo/README.md @@ -154,7 +154,7 @@ curl -C - -Ls "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokeni # Create params.json file touch params.json echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json -python -m examples.models.llama.export_llama -c stories110M.pt -p params.json -d fp16 -n stories110m_h.pte -kv +python -m extension.llm.export.export_llm base.checkpoint=stories110M.pt base.params=params.json model.dtype_override="fp16" export.output_name=stories110m_h.pte model.use_kv_cache=True python -m pytorch_tokenizers.tools.llama2c.convert -t tokenizer.model -o tokenizer.bin ``` ### Push model diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md index fb9df3c3375..360e92a5f30 100644 --- a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md +++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md @@ -97,7 +97,7 @@ cmake --build cmake-out/examples/models/llama -j16 --config Release ## Export Llama Model QNN backend currently supports exporting to these data types: fp32, int4/ int8 with PTQ, int4 with SpinQuant (Llama 3 only). -We also support export for different Qualcomm SoC. We have verified SM8650(V75) and SM8550(V73). To export for different SoC, add “--soc_model SM8550” in your export command. Without setting this flag, the export will default to SM8650. +We also support export for different Qualcomm SoC. We have verified SM8650(V75) and SM8550(V73). To export for different SoC, add "--soc_model SM8550" in your export command. Without setting this flag, the export will default to SM8650. ### Export with PTQ We support PTQ by default. The entire export may take ~20 minutes (Llama 3.1 8B). However, there is accuracy regression and we are working on improving it. @@ -106,12 +106,12 @@ We support PTQ by default. The entire export may take ~20 minutes (Llama 3.1 8B) Examples: ``` # 4 bits weight only quantize -python -m examples.models.llama.export_llama --checkpoint "${MODEL_DIR}/consolidated.00.pth" -p "${MODEL_DIR}/params.json" -kv --disable_dynamic_shape --qnn --pt2e_quantize qnn_16a4w -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="test.pte” +python -m extension.llm.export.export_llm base.checkpoint="${MODEL_DIR}/consolidated.00.pth" base.params="${MODEL_DIR}/params.json" model.use_kv_cache=True model.enable_dynamic_shape=False backend.qnn.enabled=True backend.qnn.quantization="qnn_16a4w" model.dtype_override="fp32" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="test.pte" ``` If the model is really big, it may require model sharding because the Qualcomm DSP is a 32bit system and has a 4GB size limit . For example for Llama 3 8B models, we need to shard the model into 4, but ExecuTorch still packages it into one PTE file. Here is an example: ``` # 8 bits quantization with 4 shards -python -m examples.models.llama.export_llama --checkpoint "${MODEL_DIR}/consolidated.00.pth" -p "${MODEL_DIR}/params.json" -kv --disable_dynamic_shape --qnn --pt2e_quantize qnn_8a8w -d fp32 --num_sharding 4 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="test.pte” +python -m extension.llm.export.export_llm base.checkpoint="${MODEL_DIR}/consolidated.00.pth" base.params="${MODEL_DIR}/params.json" model.use_kv_cache=True model.enable_dynamic_shape=False backend.qnn.enabled=True backend.qnn.quantization="qnn_8a8w" model.dtype_override="fp32" backend.qnn.num_sharding=4 base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="test.pte" ``` Note: if you encountered issues below ``` @@ -163,7 +163,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure * 8B models might need 16GB RAM on the device to run. ``` # Please note that calibration_data must include the prompt template for special tokens. -python -m examples.models.llama.export_llama -t -p -c --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +python -m extension.llm.export.export_llm base.tokenizer= base.params= base.checkpoint= model.use_kv_cache=True backend.qnn.enabled=True backend.qnn.quantization="qnn_16a4w" model.enable_dynamic_shape=False backend.qnn.num_sharding=8 backend.qnn.calibration_tasks="wikitext" backend.qnn.calibration_limit=1 backend.qnn.calibration_seq_length=128 backend.qnn.optimized_rotation_path= backend.qnn.calibration_data="<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ``` ## Pushing Model and Tokenizer @@ -210,17 +210,17 @@ Alternative you can also just run the shell script directly as in the root direc sh examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh ``` This is running the shell script which configures the required core ExecuTorch, Llama2/3, and Android libraries, builds them into AAR, and copies it to the app. -Note: If you are building the Android app mentioned in the next section on a separate machine (i.e. MacOS but building and exporting for QNN backend on Linux), make sure you copy the aar file generated from setup-with-qnn script to “examples/demo-apps/android/LlamaDemo/app/libs” before building the Android app. +Note: If you are building the Android app mentioned in the next section on a separate machine (i.e. MacOS but building and exporting for QNN backend on Linux), make sure you copy the aar file generated from setup-with-qnn script to "examples/demo-apps/android/LlamaDemo/app/libs" before building the Android app. ## Run the Android Demo App -First, make sure your Android phone’s chipset version is compatible with this demo (SM8650, SM8550). You can find the Qualcomm chipset version here in the [mapping](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/overview.html). +First, make sure your Android phone's chipset version is compatible with this demo (SM8650, SM8550). You can find the Qualcomm chipset version here in the [mapping](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/overview.html). -If you build and run the setup-with-qnn script on a separate machine rather than where you are building the Android app, make sure you copy the aar file it generated into “examples/demo-apps/android/LlamaDemo/app/libs” +If you build and run the setup-with-qnn script on a separate machine rather than where you are building the Android app, make sure you copy the aar file it generated into "examples/demo-apps/android/LlamaDemo/app/libs" ### Alternative 1: Android Studio (Recommended) -Open Android Studio and select “Open an existing Android Studio project” to open examples/demo-apps/android/LlamaDemo. +Open Android Studio and select "Open an existing Android Studio project" to open examples/demo-apps/android/LlamaDemo. Run the app (^R). This builds and launches the app on the phone. ### Alternative 2: Command line @@ -238,4 +238,4 @@ If the app successfully run on your device, you should see something like below:

## Reporting Issues -If you encountered any bugs or issues following this tutorial please file a bug/issue here on Github. +If you encountered any bugs or issues following this tutorial please file a bug/issue here on Github. \ No newline at end of file diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md index de99387f82d..baf8ffb7071 100644 --- a/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md +++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/xnnpack_README.md @@ -55,7 +55,7 @@ In this demo app, we support text-only inference with up-to-date Llama models an Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --max_context_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="fp32" backend.xnnpack.extended_ops=True base.preq_mode="8da4w_output_8da8w" base.preq_group_size=32 export.max_seq_length=2048 export.max_context_length=2048 base.preq_embedding_quantize=\'8,0\' quantization.use_spin_quant="native" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_spinquant.pte" ``` For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_SpinQuant_INT4_EO8.ipynb). @@ -63,7 +63,7 @@ For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --max_context_length 2048--preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= quantization.use_qat=True base.use_lora=16 model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="fp32" backend.xnnpack.extended_ops=True base.preq_mode="8da4w_output_8da8w" base.preq_group_size=32 export.max_seq_length=2048 export.max_context_length=2048 base.preq_embedding_quantize=\'8,0\' base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_qat_lora.pte" ``` For convenience, an [exported ExecuTorch QAT+LoRA model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-QLORA_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_QLORA_INT4_EO8.ipynb). @@ -74,7 +74,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="bf16" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_bf16.pte" ``` For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/llama3_2-1B.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/ExportRecipe_1B.ipynb). @@ -90,7 +90,7 @@ To safeguard your application, you can use our Llama Guard models for prompt cla * We prepared this model using the following command ``` -python -m examples.models.llama.export_llama --checkpoint --params -d fp32 -kv --use_sdpa_with_kv_cache --quantization_mode 8da4w --group_size 256 --xnnpack --max_seq_length 8193 --max_context_length 8193 --embedding-quantize 4,32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_prune_map --output_name="llama_guard_3_1b_pruned_xnnpack.pte" +python -m extension.llm.export.export_llm base.checkpoint= base.params= model.dtype_override="fp32" model.use_kv_cache=True model.use_sdpa_with_kv_cache=True quantization.qmode="8da4w" quantization.group_size=256 backend.xnnpack.enabled=True export.max_seq_length=8193 export.max_context_length=8193 quantization.embedding_quantize=\'4,32\' base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' base.output_prune_map= export.output_name="llama_guard_3_1b_pruned_xnnpack.pte" ``` @@ -100,7 +100,7 @@ python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama.pte" +python -m extension.llm.export.export_llm base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True quantization.qmode="8da4w" quantization.group_size=128 model.dtype_override="fp32" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama.pte" ``` You may wonder what the ‘--metadata’ flag is doing. This flag helps export the model with proper special tokens added that the runner can detect EOS tokens easily. diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj index ddf7f32f043..042f3903c67 100644 --- a/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj +++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj @@ -547,7 +547,7 @@ ); runOnlyForDeploymentPostprocessing = 0; shellPath = /bin/sh; - shellScript = "set -e\n\nif ! command -v cmake &> /dev/null\nthen\n echo \"Cmake not found, please install Cmake. \\n1. Download Cmake.app from https://cmake.org/download with version > 3.19. \\n2. Install it to Applications/ folder and run `sudo /Applications/CMake.app/Contents/bin/cmake-gui --install` to install CMake commandline tools.\"\n exit 1\nfi\n\nCMAKE_DIR=\"$TEMP_DIR/cmake\"\nrm -rf \"$CMAKE_DIR\"\n\nPLATFORM=\"SIMULATORARM64\"\nDEPLOYMENT_TARGET=\"17.0\"\n\nif [[ \"$PLATFORM_NAME\" == *\"iphoneos\"* ]]; then\n PLATFORM=\"OS64\"\nelif [[ \"$PLATFORM_NAME\" == *\"macos\"* ]]; then\n PLATFORM=\"MAC_ARM64\"\n DEPLOYMENT_TARGET=\"12.0\"\nfi\n\ncmake_build() {\n local src_dir=$1\n local target=$2\n shift 2\n local extra_args=(\"$@\")\n local build_dir=\"$CMAKE_DIR/build/$(basename \"$src_dir\")\"\n\n mkdir -p \"$build_dir\" && cd \"$build_dir\"\n\n if [[ \"$PLATFORM\" == \"MAC_ARM64\" ]]; then\n extra_args+=(-DCMAKE_INSTALL_BUNDLEDIR=\"${CMAKE_DIR}/bin\")\n extra_args+=(-DCMAKE_MACOSX_BUNDLE=OFF)\n fi\n cmake -G Xcode \\\n -DCMAKE_BUILD_TYPE=\"Release\" \\\n -DCMAKE_CXX_STANDARD=17 \\\n -DCMAKE_TOOLCHAIN_FILE=\"$SRCROOT/../../../../third-party/ios-cmake/ios.toolchain.cmake\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD=\"c++17\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY=\"libc++\" \\\n -DPLATFORM=\"$PLATFORM\" \\\n -DDEPLOYMENT_TARGET=\"$DEPLOYMENT_TARGET\" \\\n -DCMAKE_INSTALL_PREFIX=\"$CMAKE_DIR\" \\\n \"${extra_args[@]}\" \\\n \"$src_dir\"\n cmake --build . --config \"Release\" --target \"$target\"\n if [[ \"$target\" == \"install\" ]]; then\n cmake --install . --prefix \"$CMAKE_DIR\"\n fi\n}\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/abseil-cpp\" \"install\" \\\n -DABSL_PROPAGATE_CXX_STD=ON\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/re2\" \"install\"\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/pcre2\" \"install\" \\\n -DPCRE2_BUILD_PCRE2_8=ON \\\n -DPCRE2_BUILD_PCRE2_16=OFF \\\n -DPCRE2_BUILD_PCRE2_32=OFF \\\n -DPCRE2_BUILD_TESTS=OFF \\\n -DPCRE2_BUILD_PCRE2GREP=OFF \\\n -DPCRE2_BUILD_PCRE2TEST=OFF \\\n -DPCRE2_BUILD_PCRE2GPERF=OFF \\\n -DPCRE2_BUILD_DOCS=OFF \\\n -DPCRE2_BUILD_LIBPCRE2_PDB=OFF\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/sentencepiece\" \"sentencepiece-static\" \\\n -DSPM_ENABLE_SHARED=OFF\n \ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode\" \"install\"\n \n# Include the single header for json.\nmkdir -p \"$CMAKE_DIR/include/nlohmann\"\ncp \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/json/single_include/nlohmann/json.hpp\" \"$CMAKE_DIR/include/nlohmann/json.hpp\"\n\necho \"$(find $CMAKE_DIR/lib -name \"*.a\" | sed -E 's|^.*/lib([^/]+)\\.a|-l\\1|g' | tr '\\n' ' ')\" > \"$CMAKE_DIR/linker_flags\"\n"; + shellScript = "set -e\n\nif ! command -v cmake &> /dev/null\nthen\n echo \"Cmake not found, please install Cmake. \\n1. Download Cmake.app from https://cmake.org/download with version > 3.19. \\n2. Install it to Applications/ folder and run `sudo /Applications/CMake.app/Contents/bin/cmake-gui --install` to install CMake commandline tools.\"\n exit 1\nfi\n\nCMAKE_DIR=\"$TEMP_DIR/cmake\"\nrm -rf \"$CMAKE_DIR\"\n\nPLATFORM=\"SIMULATORARM64\"\nDEPLOYMENT_TARGET=\"17.0\"\n\nif [[ \"$PLATFORM_NAME\" == *\"iphoneos\"* ]]; then\n PLATFORM=\"OS64\"\nelif [[ \"$PLATFORM_NAME\" == *\"macos\"* ]]; then\n PLATFORM=\"MAC_ARM64\"\n DEPLOYMENT_TARGET=\"12.0\"\nfi\n\ncmake_build() {\n local src_dir target do_install=0\n local extra_args=()\n local build_dir\n # Parse arguments\n src_dir=\"$1\"\n shift\n target=\"$1\"\n if [[ \"$target\" == \"install\" ]]; then\n # Usage: cmake_build install [extra_args...]\n do_install=1\n shift\n else\n # Usage: cmake_build [install] [extra_args...]\n shift\n if [[ \"$1\" == \"install\" ]]; then\n do_install=1\n shift\n fi\n fi\n # Collect any remaining arguments as extra_args\n extra_args=(\"$@\")\n build_dir=\"$CMAKE_DIR/build/$(basename \"$src_dir\")\"\n mkdir -p \"$build_dir\" || { echo \"Failed to create build dir\"; return 1; }\n pushd \"$build_dir\" > /dev/null || { echo \"Failed to enter build dir\"; return 1; }\n # Platform-specific CMake args\n if [[ \"$PLATFORM\" == \"MAC_ARM64\" ]]; then\n extra_args+=(-DCMAKE_INSTALL_BUNDLEDIR=\"${CMAKE_DIR}/bin\")\n extra_args+=(-DCMAKE_MACOSX_BUNDLE=OFF)\n fi\n # Configure\n cmake -G Xcode \\\n -DCMAKE_BUILD_TYPE=\"Release\" \\\n -DCMAKE_CXX_STANDARD=17 \\\n -DCMAKE_TOOLCHAIN_FILE=\"$SRCROOT/../../../../third-party/ios-cmake/ios.toolchain.cmake\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD=\"c++17\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY=\"libc++\" \\\n -DPLATFORM=\"$PLATFORM\" \\\n -DDEPLOYMENT_TARGET=\"$DEPLOYMENT_TARGET\" \\\n -DCMAKE_INSTALL_PREFIX=\"$CMAKE_DIR\" \\\n \"${extra_args[@]}\" \\\n \"$src_dir\" || { echo \"CMake configure failed\"; popd > /dev/null; return 1; }\n # Build\n cmake --build . --config \"Release\" --target $target\n # Install if requested\n if [[ $do_install -eq 1 ]]; then\n cmake --install . --prefix \"$CMAKE_DIR\" || echo \"Ignoring install failures\"\n fi\n}\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/abseil-cpp\" \"install\" \\\n -DABSL_PROPAGATE_CXX_STD=ON\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/re2\" \"install\"\n\ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/pcre2\" \"install\" \\\n -DPCRE2_BUILD_PCRE2_8=ON \\\n -DPCRE2_BUILD_PCRE2_16=OFF \\\n -DPCRE2_BUILD_PCRE2_32=OFF \\\n -DPCRE2_BUILD_TESTS=OFF \\\n -DPCRE2_BUILD_PCRE2GREP=OFF \\\n -DPCRE2_BUILD_PCRE2TEST=OFF \\\n -DPCRE2_BUILD_PCRE2GPERF=OFF \\\n -DPCRE2_BUILD_DOCS=OFF \\\n -DPCRE2_BUILD_LIBPCRE2_PDB=OFF\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/sentencepiece\" \"sentencepiece-static sentencepiece_train-static\" \"install\" \\\n -DSPM_ENABLE_SHARED=OFF \\\n -DSPM_BUILD_TEST=OFF \\\n -DCMAKE_SYSTEM_NAME=\"iOS\"\n \ncmake_build \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode\" \"install\"\n \n# Include the single header for json.\nmkdir -p \"$CMAKE_DIR/include/nlohmann\"\ncp \"$SRCROOT/../../../../extension/llm/tokenizers/third-party/json/single_include/nlohmann/json.hpp\" \"$CMAKE_DIR/include/nlohmann/json.hpp\"\n\necho \"$(find $CMAKE_DIR/lib -name \"*.a\" | sed -E 's|^.*/lib([^/]+)\\.a|-l\\1|g' | tr '\\n' ' ')\" > \"$CMAKE_DIR/linker_flags\"\n"; }; /* End PBXShellScriptBuildPhase section */ @@ -858,10 +858,6 @@ DYLIB_INSTALL_NAME_BASE = "@rpath"; ENABLE_MODULE_VERIFIER = YES; GCC_C_LANGUAGE_STANDARD = gnu17; - GCC_PREPROCESSOR_DEFINITIONS = ( - "ET_USE_TIKTOKEN=1", - "SUPPORT_REGEX_LOOKAHEAD=ON", - ); GCC_PREPROCESSOR_DEFINITIONS = "SUPPORT_REGEX_LOOKAHEAD=ON"; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_NSHumanReadableCopyright = ""; diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md index 47352607bca..d6bccc0ef47 100644 --- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md +++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/mps_README.md @@ -49,7 +49,7 @@ Install the required packages to export the model Export the model ``` -python -m examples.models.llama.export_llama --checkpoint "${MODEL_DIR}/consolidated.00.pth" --params "${MODEL_DIR}/params.json" -kv --use_sdpa_with_kv_cache --mps -d fp32 --disable_dynamic_shape -qmode 8da4w -G 32 +python -m extension.llm.export.export_llm base.checkpoint="${MODEL_DIR}/consolidated.00.pth" base.params="${MODEL_DIR}/params.json" model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.mps.enabled=True model.dtype_override="fp32" model.enable_dynamic_shape=False quantization.qmode="8da4w" quantization.group_size=32 ``` ## Pushing Model and Tokenizer diff --git a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md index bb33b50f8b7..6cca65339da 100644 --- a/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md +++ b/examples/demo-apps/apple_ios/LLaMA/docs/delegates/xnnpack_README.md @@ -51,7 +51,7 @@ In this demo app, we support text-only inference with up-to-date Llama models an Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --max_context_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="fp32" backend.xnnpack.extended_ops=True base.preq_mode="8da4w_output_8da8w" base.preq_group_size=32 export.max_seq_length=2048 export.max_context_length=2048 base.preq_embedding_quantize=\'8,0\' quantization.use_spin_quant="native" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_spinquant.pte" ``` For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_SpinQuant_INT4_EO8.ipynb). @@ -59,7 +59,7 @@ For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend. * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --max_context_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= quantization.use_qat=True base.use_lora=16 model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="fp32" backend.xnnpack.extended_ops=True base.preq_mode="8da4w_output_8da8w" base.preq_group_size=32 export.max_seq_length=2048 export.max_context_length=2048 base.preq_embedding_quantize=\'8,0\' base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_qat_lora.pte" ``` For convenience, an [exported ExecuTorch QAT+LoRA model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-QLORA_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_QLORA_INT4_EO8.ipynb). @@ -69,7 +69,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B * Export Llama model and generate .pte file as below: ``` -python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint --params -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte" +python -m extension.llm.export.export_llm base.model_class="llama3_2" base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True model.dtype_override="bf16" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' export.output_name="llama3_2_bf16.pte" ``` For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/llama3_2-1B.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/ExportRecipe_1B.ipynb). @@ -79,7 +79,7 @@ For more detail using Llama 3.2 lightweight models including prompt template, pl Export the model ``` -python -m examples.models.llama.export_llama --checkpoint -p -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" +python -m extension.llm.export.export_llm base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True quantization.qmode="8da4w" quantization.group_size=128 model.dtype_override="fp32" base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' quantization.embedding_quantize=\'4,32\' export.output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" ``` ### For LLaVA model diff --git a/examples/models/deepseek-r1-distill-llama-8B/README.md b/examples/models/deepseek-r1-distill-llama-8B/README.md index 5fd47ad61ec..f05dd9990a2 100644 --- a/examples/models/deepseek-r1-distill-llama-8B/README.md +++ b/examples/models/deepseek-r1-distill-llama-8B/README.md @@ -52,18 +52,18 @@ torch.save(sd, "/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth") 5. Generate a PTE file for use with the Llama runner. ``` -python -m examples.models.llama.export_llama \ - --checkpoint /tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ - -p params.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -X \ - -qmode 8da4w \ - --group_size 128 \ - -d fp16 \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \ - --embedding-quantize 4,32 \ - --output_name="DeepSeek-R1-Distill-Llama-8B.pte" +python -m extension.llm.export.export_llm \ + base.checkpoint=/tmp/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/checkpoint.pth \ + base.params=params.json \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + quantization.qmode="8da4w" \ + quantization.group_size=128 \ + model.dtype_override="fp16" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + quantization.embedding_quantize=\'4,32\' \ + export.output_name="DeepSeek-R1-Distill-Llama-8B.pte" ``` 6. Run the model on your desktop for validation or integrate with iOS/Android apps. Instructions for these are available in the Llama [README](../llama/README.md) starting at Step 3. diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index c6f0350fff7..e555043c44d 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -167,15 +167,15 @@ Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus LLAMA_CHECKPOINT=path/to/consolidated.00.pth LLAMA_PARAMS=path/to/params.json -python -m examples.models.llama.export_llama \ - --model "llama3_2" \ - --checkpoint "${LLAMA_CHECKPOINT:?}" \ - --params "${LLAMA_PARAMS:?}" \ - -kv \ - --use_sdpa_with_kv_cache \ - -d bf16 \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \ - --output_name="llama3_2.pte" +python -m extension.llm.export.export_llm \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="bf16" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + export.output_name="llama3_2.pte" ``` For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/llama3_2-1B.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-ET/blob/main/ExportRecipe_1B.ipynb). @@ -189,23 +189,23 @@ For convenience, an [exported ExecuTorch bf16 model](https://huggingface.co/exec LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/consolidated.00.pth.pth LLAMA_PARAMS=path/to/spinquant/params.json -python -m examples.models.llama.export_llama \ - --model "llama3_2" \ - --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - --params "${LLAMA_PARAMS:?}" \ - --use_sdpa_with_kv_cache \ - -X \ - --xnnpack-extended-ops \ - --preq_mode 8da4w_output_8da8w \ - --preq_group_size 32 \ - --max_seq_length 2048 \ - --max_context_length 2048 \ - --output_name "llama3_2.pte" \ - -kv \ - -d fp32 \ - --preq_embedding_quantize 8,0 \ - --use_spin_quant native \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' +python -m extension.llm.export.export_llm \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + base.preq_mode="8da4w_output_8da8w" \ + base.preq_group_size=32 \ + export.max_seq_length=2048 \ + export.max_context_length=2048 \ + export.output_name="llama3_2.pte" \ + model.use_kv_cache=True \ + model.dtype_override="fp32" \ + base.preq_embedding_quantize=\'8,0\' \ + quantization.use_spin_quant="native" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' ``` For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_SpinQuant_INT4_EO8.ipynb). @@ -218,24 +218,24 @@ For convenience, an [exported ExecuTorch SpinQuant model](https://huggingface.co LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/consolidated.00.pth.pth LLAMA_PARAMS=path/to/qlora/params.json -python -m examples.models.llama.export_llama \ - --model "llama3_2" \ - --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \ - --params "${LLAMA_PARAMS:?}" \ - -qat \ - -lora 16 \ - --preq_mode 8da4w_output_8da8w \ - --preq_group_size 32 \ - --preq_embedding_quantize 8,0 \ - --use_sdpa_with_kv_cache \ - -kv \ - -X \ - --xnnpack-extended-ops \ - -d fp32 \ - --max_seq_length 2048 \ - --max_context_length 2048 \ - --output_name "llama3_2.pte" \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' +python -m extension.llm.export.export_llm \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_QUANTIZED_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + quantization.use_qat=True \ + base.use_lora=16 \ + base.preq_mode="8da4w_output_8da8w" \ + base.preq_group_size=32 \ + base.preq_embedding_quantize=\'8,0\' \ + model.use_sdpa_with_kv_cache=True \ + model.use_kv_cache=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + model.dtype_override="fp32" \ + export.max_seq_length=2048 \ + export.max_context_length=2048 \ + export.output_name="llama3_2.pte" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' ``` For convenience, an [exported ExecuTorch QAT+LoRA model](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Llama-3.2-1B-Instruct-QLORA_INT4_EO8.pte) is available on Hugging Face. The export was created using [this detailed recipe notebook](https://huggingface.co/executorch-community/Llama-3.2-1B-Instruct-QLORA_INT4_EO8-ET/blob/main/Export_Recipe_Llama_3_2_1B_Instruct_QLORA_INT4_EO8.ipynb). @@ -247,20 +247,20 @@ You can export and run the original Llama 3 8B instruct model. 2. Export model and generate `.pte` file ``` - python -m examples.models.llama.export_llama \ - --checkpoint \ - -p \ - -kv \ - --use_sdpa_with_kv_cache \ - -X \ - -qmode 8da4w \ - --group_size 128 \ - -d fp32 \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \ - --embedding-quantize 4,32 \ - --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" + python -m extension.llm.export.export_llm \ + base.checkpoint= \ + base.params= \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + backend.xnnpack.enabled=True \ + quantization.qmode="8da4w" \ + quantization.group_size=128 \ + model.dtype_override="fp32" \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + quantization.embedding_quantize=\'4,32\' \ + export.output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" ``` - Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size. + Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `quantization.embedding_quantize=\'4,32\'` as shown above to further reduce the model size. If you're interested in deploying on non-CPU backends, [please refer the non-cpu-backend section](non_cpu_backends.md) @@ -389,22 +389,22 @@ QLINEAR_GROUP_SIZE=128 # Must be multiple of 16 QEMBEDDING_BITWIDTH=4 # Can be 1-8 QEMBEDDING_GROUP_SIZE=32 # Must be multiple of 16 -python -m examples.models.llama.export_llama \ - --model "llama3_2" \ - --checkpoint "${LLAMA_CHECKPOINT:?}" \ - --params "${LLAMA_PARAMS:?}" \ - -kv \ - --use_sdpa_with_kv_cache \ - --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \ - --output_name="llama3_2.pte" \ - -qmode "torchao:8da${QLINEAR_BITWIDTH}w" \ - --group_size ${QLINEAR_GROUP_SIZE} \ - -E "torchao:${QEMBEDDING_BITWIDTH},${QEMBEDDING_GROUP_SIZE}" \ - -d fp32 +python -m extension.llm.export.export_llm \ + base.model_class="llama3_2" \ + base.checkpoint="${LLAMA_CHECKPOINT:?}" \ + base.params="${LLAMA_PARAMS:?}" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + base.metadata='"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"' \ + export.output_name="llama3_2.pte" \ + quantization.qmode="torchao:8da${QLINEAR_BITWIDTH}w" \ + quantization.group_size=${QLINEAR_GROUP_SIZE} \ + quantization.embedding_quantize=\'torchao:${QEMBEDDING_BITWIDTH},${QEMBEDDING_GROUP_SIZE}\' \ + model.dtype_override="fp32" ``` A few notes: -- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations. +- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `model.use_shared_embedding=True` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `quantization.embedding_quantize="torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `quantization.embedding_quantize="torchao:4,32"`), whereas `quantization.embedding_quantize="torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `model.use_shared_embedding=True` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations. - To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers. Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels. diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 86b7e957628..d2caccd5897 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -85,6 +85,7 @@ runtime.python_binary( ":export_library", "//caffe2:torch", "//executorch/extension/pybindings:aten_lib", + "//executorch/extension/llm/export:export_llm_lib", ], ) @@ -133,8 +134,6 @@ runtime.python_library( name = "export_library", srcs = [ "export_llama.py", - "export_llama_args.py", - "export_llama_hydra.py", "export_llama_lib.py", "model.py", ], diff --git a/examples/models/llama/UTILS.md b/examples/models/llama/UTILS.md index 5f760ad7670..25bd7f77080 100644 --- a/examples/models/llama/UTILS.md +++ b/examples/models/llama/UTILS.md @@ -19,7 +19,7 @@ From `executorch` root: ``` 3. Export model and generate `.pte` file. ``` - python -m examples.models.llama.export_llama -c stories110M.pt -p params.json -X -kv + python -m extension.llm.export.export_llm base.checkpoint=stories110M.pt base.params=params.json backend.xnnpack.enabled=True model.use_kv_cache=True ``` ## Smaller model delegated to other backends @@ -27,15 +27,15 @@ From `executorch` root: Currently we supported lowering the stories model to other backends, including, CoreML, MPS and QNN. Please refer to the instruction for each backend ([CoreML](https://pytorch.org/executorch/main/backends-coreml), [MPS](https://pytorch.org/executorch/main/backends-mps), [QNN](https://pytorch.org/executorch/main/backends-qualcomm)) before trying to lower them. After the backend library is installed, the script to export a lowered model is -- Lower to CoreML: `python -m examples.models.llama.export_llama -kv --disable_dynamic_shape --coreml -c stories110M.pt -p params.json ` -- MPS: `python -m examples.models.llama.export_llama -kv --disable_dynamic_shape --mps -c stories110M.pt -p params.json ` -- QNN: `python -m examples.models.llama.export_llama -kv --disable_dynamic_shape --qnn -c stories110M.pt -p params.json ` +- Lower to CoreML: `python -m extension.llm.export.export_llm model.use_kv_cache=True model.enable_dynamic_shape=False backend.coreml.enabled=True base.checkpoint=stories110M.pt base.params=params.json` +- MPS: `python -m extension.llm.export.export_llm model.use_kv_cache=True model.enable_dynamic_shape=False backend.mps.enabled=True base.checkpoint=stories110M.pt base.params=params.json` +- QNN: `python -m extension.llm.export.export_llm model.use_kv_cache=True model.enable_dynamic_shape=False backend.qnn.enabled=True base.checkpoint=stories110M.pt base.params=params.json` The iOS LLAMA app supports the CoreML and MPS model and the Android LLAMA app supports the QNN model. On Android, it also allow to cross compiler the llama runner binary, push to the device and run. For CoreML, there are 2 additional optional arguments: -* `--coreml-ios`: Specify the minimum iOS version to deploy (and turn on available optimizations). E.g. `--coreml-ios 18` will turn on [in-place KV cache](https://developer.apple.com/documentation/coreml/mlstate?language=objc) and [fused scaled dot product attention kernel](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS18.transformers.scaled_dot_product_attention) (the resulting model will then need at least iOS 18 to run, though) -* `--coreml-quantize`: Use [quantization tailored for CoreML](https://apple.github.io/coremltools/docs-guides/source/opt-quantization-overview.html). E.g. `--coreml-quantize b4w` will perform per-block 4-bit weight-only quantization in a way tailored for CoreML +* `backend.coreml.ios`: Specify the minimum iOS version to deploy (and turn on available optimizations). E.g. `backend.coreml.ios=18` will turn on [in-place KV cache](https://developer.apple.com/documentation/coreml/mlstate?language=objc) and [fused scaled dot product attention kernel](https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS18.transformers.scaled_dot_product_attention) (the resulting model will then need at least iOS 18 to run, though) +* `backend.coreml.quantize`: Use [quantization tailored for CoreML](https://apple.github.io/coremltools/docs-guides/source/opt-quantization-overview.html). E.g. `backend.coreml.quantize="b4w"` will perform per-block 4-bit weight-only quantization in a way tailored for CoreML To deploy the large 8B model on the above backends, [please visit this section](non_cpu_backends.md). diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py index a5c486a8c1e..9acd633fb21 100644 --- a/examples/models/llama/config/llm_config.py +++ b/examples/models/llama/config/llm_config.py @@ -26,19 +26,19 @@ class ModelType(str, Enum): - STORIES110M = "stories110m" - LLAMA2 = "llama2" - LLAMA3 = "llama3" - LLAMA3_1 = "llama3_1" - LLAMA3_2 = "llama3_2" - LLAMA3_2_VISION = "llama3_2_vision" - STATIC_LLAMA = "static_llama" - QWEN2_5 = "qwen2_5" - QWEN3_0_6B = "qwen3-0_6b" - QWEN3_1_7B = "qwen3-1_7b" - QWEN3_4B = "qwen3-4b" - PHI_4_MINI = "phi_4_mini" - SMOLLM2 = "smollm2" + stories110m = "stories110m" + llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_2_vision = "llama3_2_vision" + static_llama = "static_llama" + qwen2_5 = "qwen2_5" + qwen3_0_6b = "qwen3-0_6b" + qwen3_1_7b = "qwen3-1_7b" + qwen3_4b = "qwen3-4b" + phi_4_mini = "phi_4_mini" + smollm2 = "smollm2" class PreqMode(str, Enum): @@ -49,8 +49,8 @@ class PreqMode(str, Enum): are still around to preserve backward compatibility. """ - PREQ_8DA4W = "8da4w" - PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w" + preq_8da4w = "8da4w" + preq_8da4w_out_8da8w = "8da4w_output_8da8w" @dataclass @@ -65,7 +65,9 @@ class BaseConfig: params: Model parameters, such as n_layers, hidden_size, etc. If left empty will use defaults specified in model_args.py. checkpoint: Path to the checkpoint file. - If left empty, the model will be initialized with random weights. + If left empty, the model will either be initialized with random weights + if it is a Llama model or the weights will be downloaded from HuggingFace + if it is a non-Llama model. checkpoint_dir: Path to directory containing sharded checkpoint files. tokenizer_path: Path to the tokenizer file. metadata: Json string containing metadata information. @@ -80,13 +82,13 @@ class BaseConfig: are loaded. """ - model_class: ModelType = ModelType.LLAMA3 + model_class: ModelType = ModelType.llama3 params: Optional[str] = None checkpoint: Optional[str] = None checkpoint_dir: Optional[str] = None tokenizer_path: Optional[str] = None metadata: Optional[str] = None - use_lora: int = int + use_lora: int = 0 fairseq2: bool = False preq_mode: Optional[PreqMode] = None preq_group_size: int = 32 @@ -105,9 +107,9 @@ class DtypeOverride(str, Enum): is not recommended. """ - FP32 = "fp32" - FP16 = "fp16" - BF16 = "bf16" + fp32 = "fp32" + fp16 = "fp16" + bf16 = "bf16" @dataclass @@ -145,7 +147,7 @@ class ModelConfig: [16] pattern specifies all layers have a sliding window of 16. """ - dtype_override: DtypeOverride = DtypeOverride.FP32 + dtype_override: DtypeOverride = DtypeOverride.fp32 enable_dynamic_shape: bool = True use_shared_embedding: bool = False use_sdpa_with_kv_cache: bool = False @@ -214,7 +216,7 @@ class ExportConfig: max_seq_length: int = 128 max_context_length: int = 128 - output_dir: Optional[str] = None + output_dir: str = "." output_name: Optional[str] = None so_library: Optional[str] = None export_only: bool = False @@ -268,22 +270,22 @@ class Pt2eQuantize(str, Enum): and is source transform-based. """ - XNNPACK_DYNAMIC = "xnnpack_dynamic" - XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4" - QNN_8A8W = "qnn_8a8w" - QNN_16A16W = "qnn_16a16w" - QNN_16A4W = "qnn_16a4w" - COREML_C4W = "coreml_c4w" - COREML_8A_C8W = "coreml_8a_c8w" - COREML_8A_C4W = "coreml_8a_c4w" - COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w" - COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w" - VULKAN_8W = "vulkan_8w" + xnnpack_dynamic = "xnnpack_dynamic" + xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4" + qnn_8a8w = "qnn_8a8w" + qnn_16a16w = "qnn_16a16w" + qnn_16a4w = "qnn_16a4w" + coreml_c4w = "coreml_c4w" + coreml_8a_c8w = "coreml_8a_c8w" + coreml_8a_c4w = "coreml_8a_c4w" + coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w" + coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w" + vulkan_8w = "vulkan_8w" class SpinQuant(str, Enum): - CUDA = "cuda" - NATIVE = "native" + cuda = "cuda" + native = "native" @dataclass @@ -376,15 +378,15 @@ class XNNPackConfig: class CoreMLQuantize(str, Enum): - B4W = "b4w" - C4W = "c4w" + b4w = "b4w" + c4w = "c4w" class CoreMLComputeUnit(str, Enum): - CPU_ONLY = "cpu_only" - CPU_AND_GPU = "cpu_and_gpu" - CPU_AND_NE = "cpu_and_ne" - ALL = "all" + cpu_only = "cpu_only" + cpu_and_gpu = "cpu_and_gpu" + cpu_and_ne = "cpu_and_ne" + all = "all" @dataclass @@ -398,7 +400,7 @@ class CoreMLConfig: preserve_sdpa: bool = False quantize: Optional[CoreMLQuantize] = None ios: int = 15 - compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY + compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only def __post_init__(self): if self.ios not in (15, 16, 17, 18): diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 63e76e28ba9..93782b00e37 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -17,6 +17,11 @@ import torch +from executorch.examples.models.llama.export_llama_lib import ( + build_args_parser, + export_llama, +) + sys.setrecursionlimit(4096) @@ -39,15 +44,12 @@ def main() -> None: sys.argv = [arg for arg in sys.argv if arg != "--hydra"] print(f"running with {sys.argv}") runpy.run_module( - "executorch.examples.models.llama.export_llama_hydra", run_name="__main__" + "executorch.extension.llm.export.export_llm", run_name="__main__" ) else: - # Use the legacy version of the export_llama script which uses argsparse. - from executorch.examples.models.llama.export_llama_args import ( - main as export_llama_args_main, - ) - - export_llama_args_main(remaining_args) + parser = build_args_parser() + remaining_args = parser.parse_args(remaining_args) + export_llama(remaining_args) if __name__ == "__main__": diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py deleted file mode 100644 index 7a176d9b7d0..00000000000 --- a/examples/models/llama/export_llama_args.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Run export_llama with the legacy argparse setup. -""" - -from .export_llama_lib import build_args_parser, export_llama - - -def main(args) -> None: - parser = build_args_parser() - args = parser.parse_args(args) - export_llama(args) - - -if __name__ == "__main__": - main() diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py deleted file mode 100644 index 4871de00e25..00000000000 --- a/examples/models/llama/export_llama_hydra.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Run export_llama using the new Hydra CLI. -""" - -import hydra - -from executorch.examples.models.llama.config.llm_config import LlmConfig -from executorch.examples.models.llama.export_llama_lib import export_llama -from hydra.core.config_store import ConfigStore -from omegaconf import OmegaConf - -cs = ConfigStore.instance() -cs.store(name="llm_config", node=LlmConfig) - - -@hydra.main(version_base=None, config_name="llm_config") -def main(llm_config: LlmConfig) -> None: - export_llama(OmegaConf.to_object(llm_config)) - - -if __name__ == "__main__": - main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 1f055d65822..334f3ace712 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -53,6 +53,8 @@ ) from executorch.util.activation_memory_profiler import generate_memory_trace +from omegaconf import DictConfig + from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( fuse_layer_norms, @@ -571,12 +573,14 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: def export_llama( - export_options: Union[argparse.Namespace, LlmConfig], + export_options: Union[argparse.Namespace, LlmConfig, DictConfig], ) -> str: if isinstance(export_options, argparse.Namespace): # Legacy CLI. llm_config = LlmConfig.from_args(export_options) - elif isinstance(export_options, LlmConfig): + elif isinstance(export_options, LlmConfig) or isinstance( + export_options, DictConfig + ): # Hydra CLI. llm_config = export_options else: @@ -586,7 +590,7 @@ def export_llama( # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. - model_name = llm_config.base.model_class + model_name = llm_config.base.model_class.value if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS: repo_id = HUGGING_FACE_REPO_IDS[model_name] if model_name == "qwen2_5": @@ -664,7 +668,7 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: llm_config.export.output_dir = output_dir_path # Convert dtype override string to actual type. - dtype_override = DType[llm_config.model.dtype_override] + dtype_override = DType[llm_config.model.dtype_override.value] edge_manager = _load_llama_model(llm_config) @@ -698,7 +702,11 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: checkpoint=llm_config.base.checkpoint, checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore tokenizer_path=llm_config.base.tokenizer_path, - use_spin_quant=llm_config.quantization.use_spin_quant, + use_spin_quant=( + llm_config.quantization.use_spin_quant.value + if llm_config.quantization.use_spin_quant + else None + ), embedding_quantize=llm_config.quantization.embedding_quantize, use_shared_embedding=llm_config.model.use_shared_embedding, quantization_mode=llm_config.quantization.qmode, @@ -722,7 +730,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: vulkan=llm_config.backend.vulkan.enabled, use_qat=llm_config.quantization.use_qat, use_lora=llm_config.base.use_lora, - preq_mode=llm_config.base.preq_mode, + preq_mode=( + llm_config.base.preq_mode.value if llm_config.base.preq_mode else None + ), preq_group_size=llm_config.base.preq_group_size, preq_embedding_quantize=llm_config.base.preq_embedding_quantize, local_global_attention=llm_config.model.local_global_attention, @@ -734,25 +744,34 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager: def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( - llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode + ( + llm_config.quantization.pt2e_quantize.value + if llm_config.quantization.pt2e_quantize + else None + ), + llm_config.quantization.qmode, ) quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library) quant_dtype = None if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" qnn_quantizer, quant_dtype = get_qnn_quantizer( - llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode + llm_config.quantization.pt2e_quantize.value, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" - coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize) + coreml_quantizer = get_coreml_quantizer( + llm_config.quantization.pt2e_quantize.value + ) quantizers.append(coreml_quantizer) if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: assert ( len(quantizers) == 0 ), "Should not enable both vulkan and other quantizers" - vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize) + vulkan_quantizer = get_vulkan_quantizer( + llm_config.quantization.pt2e_quantize.value + ) quantizers.append(vulkan_quantizer) logging.info(f"Applying quantizers: {quantizers}") return pt2e_quant_params, quantizers, quant_dtype @@ -914,6 +933,7 @@ def _to_edge_and_lower_llama( # noqa: C901 # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes` from executorch.backends.qualcomm._passes import ( AnnotateStack, + ConvertBmmToMatmul, FoldQDQ, RecomposeRmsNorm, TagQuantIO, @@ -956,6 +976,7 @@ def _to_edge_and_lower_llama( # noqa: C901 passes_job = get_capture_program_passes() dep_table = get_passes_dependency_for_capture_program() passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True + passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ @@ -1029,7 +1050,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 ) additional_passes = [] - if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS: + if llm_config.base.model_class.value in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] # export_to_edge @@ -1068,14 +1089,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 mps=llm_config.backend.mps.enabled, coreml=llm_config.backend.coreml.enabled, qnn=llm_config.backend.qnn.enabled, - dtype_override=llm_config.model.dtype_override, + dtype_override=llm_config.model.dtype_override.value, enable_dynamic_shape=llm_config.model.enable_dynamic_shape, use_kv_cache=llm_config.model.use_kv_cache, embedding_quantize=llm_config.quantization.embedding_quantize, - pt2e_quantize=llm_config.quantization.pt2e_quantize, + pt2e_quantize=( + llm_config.quantization.pt2e_quantize.value + if llm_config.quantization.pt2e_quantize + else None + ), coreml_ios=llm_config.backend.coreml.ios, - coreml_quantize=llm_config.backend.coreml.quantize, - coreml_compute_units=llm_config.backend.coreml.compute_units, + coreml_quantize=( + llm_config.backend.coreml.quantize.value + if llm_config.backend.coreml.quantize + else None + ), + coreml_compute_units=llm_config.backend.coreml.compute_units.value, use_qnn_sha=llm_config.backend.qnn.use_sha, num_sharding=llm_config.backend.qnn.num_sharding, soc_model=llm_config.backend.qnn.soc_model, @@ -1148,7 +1177,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": An instance of LLMEdgeManager which contains the eager mode model. """ - modelname = llm_config.base.model_class + modelname = llm_config.base.model_class.value if modelname in EXECUTORCH_DEFINED_MODELS: module_name = "llama" model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. @@ -1169,7 +1198,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": ) ) # Convert dtype override string to actual type. - dtype_override = DType[llm_config.model.dtype_override] + dtype_override = DType[llm_config.model.dtype_override.value] return LLMEdgeManager( model=model, diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index 38009dd59ec..5d34bf932e7 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -42,6 +42,16 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); +DEFINE_int32( + num_bos, + 0, + "Number of BOS tokens to prepend to the prompt. Defaults to 0. If > 0, the prompt will be prepended with BOS tokens. This is useful for models that expect one or more BOS token at the start."); + +DEFINE_int32( + num_eos, + 0, + "Number of EOS tokens to append to the prompt. Defaults to 0. If > 0, the prompt will be appended with EOS tokens. This is useful for models that expect one or more EOS token at the end."); + DEFINE_bool(warmup, false, "Whether to run a warmup run."); int32_t main(int32_t argc, char** argv) { diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index ec9646be6f4..efea80dde2f 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -157,7 +157,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None): if model_args.use_scaled_rope: # Older models don't have use_scaled_rope configuration - model_name = str(self.llm_config.base.model_class) + model_name = self.llm_config.base.model_class.value assert model_name not in ["llama2", "stories110m"] # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor @@ -328,10 +328,10 @@ def get_example_inputs_kvcache_sdpa(self): def _transform_for_pre_quantization(self, checkpoint, model_args): assert self.llm_config.base.preq_mode, "preq_mode must be specified" - assert self.llm_config.base.preq_mode in [ + assert self.llm_config.base.preq_mode.value in [ "8da4w", "8da4w_output_8da8w", - ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant." + ], f"Quantization mode {self.llm_config.base.preq_mode.value} is not compatible with SpinQuant." assert self.llm_config.base.preq_group_size, "preq_group_size must be specified" assert self.llm_config.model.dtype_override, "dtype_override must be specified" @@ -351,7 +351,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): } # Transform the output layer first if needed. - if self.llm_config.base.preq_mode == "8da4w_output_8da8w": + if self.llm_config.base.preq_mode.value == "8da4w_output_8da8w": from .source_transformation.pre_quantization import ( transform_output_linear_for_pre_quantization, ) @@ -359,14 +359,14 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, - dtype=mapping[self.llm_config.model.dtype_override], + dtype=mapping[self.llm_config.model.dtype_override.value], ) self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, self.llm_config.base.preq_group_size, - mapping[self.llm_config.model.dtype_override], + mapping[self.llm_config.model.dtype_override.value], ) embedding_bit_width, embedding_group_size = None, None @@ -390,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, - mapping[self.llm_config.model.dtype_override], + mapping[self.llm_config.model.dtype_override.value], int(embedding_bit_width), embedding_group_size, ) diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py index d5f065550d2..3c37cddee69 100644 --- a/examples/models/llama/source_transformation/attention.py +++ b/examples/models/llama/source_transformation/attention.py @@ -45,12 +45,10 @@ def __init__( self.register_buffer( f"past_k_caches_{i}", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) self.register_buffer( f"past_v_caches_{i}", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) def update( diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 1fb3d97a9c7..59823b533a3 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -493,12 +493,10 @@ def __init__( self.register_buffer( "past_k_caches", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) self.register_buffer( "past_v_caches", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) def update( diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index ce3b01b6d68..57b5796cbb3 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -47,29 +47,39 @@ def calculate_cache_key(layer_id: int, head_id: int) -> str: return f"l{layer_id},h{head_id}" @staticmethod - def apply_update(cache, update, pos, style, transpose=False): + def apply_update( + cache, update, pos, style, transpose=False, update_pos=0, update_len=None + ): """ After inference, update the cache state for next iteration. The runtime needs to implement the same operation. """ if style == "shift_pointer": if transpose: - update_len = update.size(-1) + update_len = update_len or update.size(-1) updated = torch.roll(cache, -update_len, -1) - updated[:, :, -update_len:] = update + updated[:, :, -update_len:] = update[ + :, :, update_pos : update_pos + update_len + ] else: - update_len = update.size(-2) + update_len = update_len or update.size(-2) updated = torch.roll(cache, -update_len, -2) - updated[:, -update_len:, :] = update + updated[:, -update_len:, :] = update[ + :, update_pos : update_pos + update_len, : + ] if style == "smart_mask": updated = torch.clone(cache) if transpose: - update_len = update.size(-1) - updated[:, :, pos : pos + update_len] = update + update_len = update_len or update.size(-1) + updated[:, :, pos : pos + update_len] = update[ + :, :, update_pos : update_pos + update_len + ] else: - update_len = update.size(-2) - updated[:, pos : pos + update_len, :] = update + update_len = update_len or update.size(-2) + updated[:, pos : pos + update_len, :] = update[ + :, update_pos : update_pos + update_len, : + ] return updated @@ -163,6 +173,164 @@ def unmask(self, new_unmasked_len): self.unmasked_len += new_unmasked_len +class StaticAttentionIOManager: + def __init__( + self, + config: ModelArgs, + input_len: int, + cache_len: int, + style: str = "shift_pointer", + mask_val: float = float("-inf"), + ): + self.mask = StaticAttentionMask( + input_len, cache_len, style=style, mask_val=mask_val + ) + + rope = Rope(config) + freqs = rope.get_freqs(None, config.max_seq_len) + self.freqs_cos = freqs[0] + self.freqs_sin = freqs[1] + + self.k_caches = { + StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( + 1, cache_len, config.head_dim + ) + for layer_id in range(config.n_layers) + for head_id in range(config.n_kv_heads) + } + self.v_caches = { + StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros( + 1, cache_len, config.head_dim + ) + for layer_id in range(config.n_layers) + for head_id in range(config.n_kv_heads) + } + + self.config = config + self.input_len = input_len + self.cache_len = cache_len + self.style = style + self.mask_val = mask_val + self.pos = 0 + self.cache_full = False + + def reset(self): + self.pos = 0 + self.cache_full = False + self.mask.reset() + + def prefill( + self, + model: Callable[..., Any], + tokens: List[int], + ) -> torch.Tensor: + if self.cache_full: + raise RuntimeError("KV cache is full.") + + self.mask.tensor[:, :, self.cache_len :] = torch.triu( + torch.full((1, self.input_len, self.input_len), self.mask_val), + diagonal=1, + ) + + logits = None + all_logits = None + for i in range(0, len(tokens), self.input_len): + logits = self._run_once(model, tokens[i : i + self.input_len])[0] + if self.config.generate_full_logits: + if all_logits is None: + all_logits = logits + else: + all_logits = torch.cat([all_logits, logits], dim=1) + + if self.config.generate_full_logits: + return all_logits[:, : len(tokens), :] + + return logits + + def decode( + self, + model: Callable[..., Any], + init_token: int, + n: int, + stop_tokens: Optional[List[int]] = None, + ): + if self.cache_full: + raise RuntimeError("KV cache is full.") + + self.mask.tensor[:, :, self.cache_len :] = torch.triu( + torch.full((1, self.input_len, self.input_len), self.mask_val), + diagonal=1, + ) + + stop_tokens = stop_tokens or [] + new_tokens = [init_token] + for _ in range(n): + y = self._run_once(model, new_tokens[-1:])[0] + new_tokens.append(y[:, :1, :].argmax().item()) + if new_tokens[-1] in stop_tokens: + break + + return new_tokens + + def _run_once( + self, + model: Callable[..., Any], + tokens: List[int], + non_padded_len: Optional[int] = None, + freqs_cos_override: Optional[torch.Tensor] = None, + freqs_sin_override: Optional[torch.Tensor] = None, + ): + n_tokens = len(tokens) + if n_tokens < self.input_len: + tokens += [0] * (self.input_len - n_tokens) + tokens = torch.tensor([tokens], dtype=torch.int32) + if freqs_cos_override is None: + freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len] + if freqs_sin_override is None: + freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len] + y, attn_updates = model( + tokens, + { + "mask": self.mask.tensor, + "freqs_cos_override": freqs_cos_override, + "freqs_sin_override": freqs_sin_override, + "in_cache_state": (self.k_caches, self.v_caches), + }, + ) + non_padded_len = non_padded_len or n_tokens + if self.pos + non_padded_len <= self.cache_len: + self._update_states(attn_updates, 0, non_padded_len) + else: + self.cache_full = True + + return y, attn_updates + + def _update_states(self, attn_updates, update_pos, update_len): + assert self.pos + update_len <= self.cache_len + + self.mask.unmask(update_len) + k_cache_updates, v_cache_updates = attn_updates["out_cache_state"] + for cache_id, update in k_cache_updates.items(): + self.k_caches[cache_id] = StaticKVCache.apply_update( + self.k_caches[cache_id], + update, + self.pos, + style=self.style, + update_pos=update_pos, + update_len=update_len, + ) + for cache_id, update in v_cache_updates.items(): + self.v_caches[cache_id] = StaticKVCache.apply_update( + self.v_caches[cache_id], + update, + self.pos, + style=self.style, + update_pos=update_pos, + update_len=update_len, + ) + self.pos += update_len + + class _Rope(nn.Module): def __init__(self, use_hf_rope): super().__init__() @@ -184,7 +352,7 @@ def forward( x_r, x_i = x[..., ::2], x[..., 1::2] x_out_r = x_r * freqs_cos - x_i * freqs_sin x_out_i = x_r * freqs_sin + x_i * freqs_cos - x_out = torch.cat([x_out_r, x_out_i], dim=-1) + x_out = torch.stack([x_out_r, x_out_i], dim=-1).flatten(2) return x_out @@ -210,6 +378,7 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope): self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5) self.attention_qkv_bias = config.attention_qkv_bias self.use_qk_norm = config.use_qk_norm + self.qk_norm_before_rope = config.qk_norm_before_rope self.use_conv2d = False self.wqs = nn.ModuleList( @@ -281,12 +450,17 @@ def from_conv2ds(ts): new_ks = from_conv2ds(new_ks) new_vs = from_conv2ds(new_vs) - if self.use_qk_norm: + if self.use_qk_norm and self.qk_norm_before_rope: new_qs = [self.q_norm(q) for q in new_qs] new_ks = [self.k_norm(k) for k in new_ks] new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs] new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks] + + if self.use_qk_norm and not self.qk_norm_before_rope: + new_qs = [self.q_norm(q) for q in new_qs] + new_ks = [self.k_norm(k) for k in new_ks] + all_ks = [] all_vs = [] for i in range(self.n_kv_heads): @@ -337,6 +511,7 @@ def load_weights_from_attention_mha(self, other: AttentionMHA): if other.use_qk_norm: self.use_qk_norm = True + self.qk_norm_before_rope = other.qk_norm_before_rope self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps) self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps) diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py index 77b8be5d401..a6eac24db1f 100644 --- a/examples/models/llama/tests/test_static_attention.py +++ b/examples/models/llama/tests/test_static_attention.py @@ -1,12 +1,13 @@ import unittest import torch -from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions +from executorch.examples.models.llama.attention import AttentionMHA from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope from executorch.examples.models.llama.static_attention import ( StaticAttention, + StaticAttentionIOManager, StaticAttentionMask, StaticKVCache, ) @@ -29,6 +30,14 @@ def test(use_qk_norm, use_conv2d): rope = Rope(config) attn_mha = AttentionMHA(config, layer_id, rope).eval() static_attn = StaticAttention(config, layer_id, rope).eval() + if use_qk_norm: + with torch.no_grad(): + attn_mha.q_norm_fn.weight.copy_( + torch.rand(config.head_dim) * 0.2 + 0.9 + ) + attn_mha.k_norm_fn.weight.copy_( + torch.rand(config.head_dim) * 0.2 + 0.9 + ) static_attn.load_weights_from_attention_mha(attn_mha) if use_conv2d: static_attn.linear_to_conv2d() @@ -59,11 +68,15 @@ def test_hf_rope_without_cache(self): n_heads=4, n_kv_heads=2, max_seq_len=8, + use_qk_norm=True, use_hf_rope=True, ) layer_id = 0 rope = Rope(config) attn_mha = AttentionMHA(config, layer_id, rope).eval() + with torch.no_grad(): + attn_mha.q_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9) + attn_mha.k_norm_fn.weight.copy_(torch.rand(config.head_dim) * 0.2 + 0.9) static_attn = StaticAttention(config, layer_id, rope).eval() static_attn.load_weights_from_attention_mha(attn_mha) @@ -171,8 +184,6 @@ def test_within_transformer(self): static_layer.attention.load_weights_from_attention_mha(mha_layer.attention) x = torch.randint(config.vocab_size, (1, config.max_seq_len)) - rope = Rope(config) - freqs_cos, freqs_sin = rope.get_freqs(None, config.max_seq_len) expected = mha_transformer(x) n_chunks = 3 @@ -180,53 +191,14 @@ def test_within_transformer(self): cache_len = config.max_seq_len - chunk_len def test_with_style(style): - mask = StaticAttentionMask(chunk_len, cache_len, style=style) - mask.tensor[:, :, cache_len:] = torch.triu( - torch.full((1, chunk_len, chunk_len), float("-inf")), - diagonal=1, - ) - k_caches = { - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( - 1, cache_len, config.head_dim - ) - for layer_id in range(config.n_layers) - for i in range(config.n_kv_heads) - } - v_caches = { - StaticKVCache.calculate_cache_key(layer_id, i): torch.zeros( - 1, cache_len, config.head_dim - ) - for layer_id in range(config.n_layers) - for i in range(config.n_kv_heads) - } + mgr = StaticAttentionIOManager(config, chunk_len, cache_len, style=style) ys = [] for i in range(n_chunks): - y_i, attn_update = static_transformer( - x[:, i * chunk_len : (i + 1) * chunk_len], - attn_options=ForwardOptions( - mask=mask.tensor, - freqs_cos_override=freqs_cos[ - i * chunk_len : (i + 1) * chunk_len - ], - freqs_sin_override=freqs_sin[ - i * chunk_len : (i + 1) * chunk_len - ], - in_cache_state=(k_caches, v_caches), - out_cache_state=({}, {}), - ), + y_i = mgr.prefill( + static_transformer, + x[0][i * chunk_len : (i + 1) * chunk_len].tolist(), ) ys.append(y_i) - mask.unmask(chunk_len) - k_cache_updates, v_cache_updates = attn_update["out_cache_state"] - if i < n_chunks - 1: - for cache_id, update in k_cache_updates.items(): - k_caches[cache_id] = StaticKVCache.apply_update( - k_caches[cache_id], update, pos=chunk_len * i, style=style - ) - for cache_id, update in v_cache_updates.items(): - v_caches[cache_id] = StaticKVCache.apply_update( - v_caches[cache_id], update, pos=chunk_len * i, style=style - ) self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all()) diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 615ad3948fc..21f761b7f71 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -37,7 +37,7 @@ You can export and run the original Llama 2 7B model. 3. Export model and generate `.pte` file: ``` - python -m examples.models.llama.export_llama --checkpoint --params -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 + python -m extension.llm.export.export_llm base.checkpoint= base.params= model.use_kv_cache=True model.use_sdpa_with_kv_cache=True backend.xnnpack.enabled=True quantization.qmode="8da4w" quantization.group_size=128 model.dtype_override="fp32" ``` 4. Create tokenizer.bin. ``` diff --git a/examples/models/moshi/mimi/install_requirements.sh b/examples/models/moshi/mimi/install_requirements.sh index ef915ca7eb2..828bcac0abb 100755 --- a/examples/models/moshi/mimi/install_requirements.sh +++ b/examples/models/moshi/mimi/install_requirements.sh @@ -7,7 +7,7 @@ set -x -pip install -U moshi +pip install moshi==0.2.4 pip install bitsandbytes soundfile # Run llama2/install requirements for torchao deps SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) diff --git a/examples/models/phi_4_mini/README.md b/examples/models/phi_4_mini/README.md index a23e4f49638..d168d54226e 100644 --- a/examples/models/phi_4_mini/README.md +++ b/examples/models/phi_4_mini/README.md @@ -7,9 +7,9 @@ Phi-4-mini uses the same example code as Llama, while the checkpoint, model para All commands for exporting and running Llama on various backends should also be applicable to Phi-4-mini, by swapping the following args: ``` ---model phi_4_mini ---params examples/models/phi-4-mini/config.json ---checkpoint +base.model_class="phi_4_mini" +base.params="examples/models/phi-4-mini/config.json" +base.checkpoint= ``` ### Generate the Checkpoint @@ -32,17 +32,17 @@ Export to XNNPack, no quantization: # Set these paths to point to the downloaded files PHI_CHECKPOINT=path/to/checkpoint.pth -python -m examples.models.llama.export_llama \ - --model phi_4_mini \ - --checkpoint "${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ - --params examples/models/phi-4-mini/config.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -d fp32 \ - -X \ - --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \ - --output_name="phi-4-mini.pte" - --verbose +python -m extension.llm.export.export_llm \ + base.model_class="phi_4_mini" \ + base.checkpoint="${PHI_CHECKPOINT=path/to/checkpoint.pth:?}" \ + base.params="examples/models/phi-4-mini/config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ + export.output_name="phi-4-mini.pte" \ + debug.verbose=True ``` Run using the executor runner: diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md index 9bf791a35ed..57784169ece 100644 --- a/examples/models/qwen2_5/README.md +++ b/examples/models/qwen2_5/README.md @@ -7,9 +7,9 @@ Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5, by swapping the following args: ``` ---model qwen2_5 ---params examples/models/qwen2_5/1_5b_config.json ---checkpoint +base.model_class="qwen2_5" +base.params="examples/models/qwen2_5/1_5b_config.json" +base.checkpoint= ``` ### Generate the Checkpoint @@ -32,17 +32,17 @@ Export to XNNPack, no quantization: # Set these paths to point to the downloaded files QWEN_CHECKPOINT=path/to/checkpoint.pth -python -m examples.models.llama.export_llama \ - --model "qwen2_5" \ - --checkpoint "${QWEN_CHECKPOINT:?}" \ - --params examples/models/qwen2_5/1_5b_config.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -d fp32 \ - -X \ - --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \ - --output_name="qwen2_5-1_5b.pte" - --verbose +python -m extension.llm.export.export_llm \ + base.model_class="qwen2_5" \ + base.checkpoint="${QWEN_CHECKPOINT:?}" \ + base.params="examples/models/qwen2_5/1_5b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + base.metadata='"{\"get_bos_id\":151643, \"get_eos_ids\":[151643]}"' \ + export.output_name="qwen2_5-1_5b.pte" \ + debug.verbose=True ``` Run using the executor runner: diff --git a/examples/models/qwen3/README.md b/examples/models/qwen3/README.md index a589d27c19d..d31d491adf2 100644 --- a/examples/models/qwen3/README.md +++ b/examples/models/qwen3/README.md @@ -7,8 +7,8 @@ Qwen 3 uses the same example code as our optimized Llama model, while the checkp All commands for exporting and running Llama on various backends should also be applicable to Qwen 3, by swapping the following args: ``` ---model [qwen3-0.6b,qwen3-1_7b,qwen3-4b] ---params [examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json] +base.model_class=[qwen3-0_6b,qwen3-1_7b,qwen3-4b] +base.params=[examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json] ``` ### Example export @@ -16,50 +16,50 @@ Here is a basic example for exporting Qwen 3, although please refer to the Llama Export 0.6b to XNNPack, quantized with 8da4w: ``` -python -m examples.models.llama.export_llama \ - --model qwen3-0_6b \ - --params examples/models/qwen3/0_6b_config.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -d fp32 \ - -X \ - --xnnpack-extended-ops \ - -qmode 8da4w \ - --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ - --output_name="qwen3-0_6b.pte" \ - --verbose +python -m extension.llm.export.export_llm \ + base.model_class="qwen3-0_6b" \ + base.params="examples/models/qwen3/0_6b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3-0_6b.pte" \ + debug.verbose=True ``` Export 1.7b to XNNPack, quantized with 8da4w: ``` -python -m examples.models.llama.export_llama \ - --model qwen3-1_7b \ - --params examples/models/qwen3/1_7b_config.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -d fp32 \ - -X \ - --xnnpack-extended-ops \ - -qmode 8da4w \ - --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ - --output_name="qwen3-1_7b.pte" \ - --verbose +python -m extension.llm.export.export_llm \ + base.model_class="qwen3-1_7b" \ + base.params="examples/models/qwen3/1_7b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3-1_7b.pte" \ + debug.verbose=True ``` Export 4b to XNNPack, quantized with 8da4w: ``` -python -m examples.models.llama.export_llama \ - --model qwen3-4b \ - --params examples/models/qwen3/4b_config.json \ - -kv \ - --use_sdpa_with_kv_cache \ - -d fp32 \ - -X \ - --xnnpack-extended-ops \ - -qmode 8da4w \ - --metadata '{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ - --output_name="qwen3-4b.pte" \ - --verbose +python -m extension.llm.export.export_llm \ + base.model_class="qwen3-4b" \ + base.params="examples/models/qwen3/4b_config.json" \ + model.use_kv_cache=True \ + model.use_sdpa_with_kv_cache=True \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" \ + base.metadata='"{\"get_bos_id\": 151644, \"get_eos_ids\":[151645]}"' \ + export.output_name="qwen3-4b.pte" \ + debug.verbose=True ``` ### Example run diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 4b0e6b2d3a2..757c7518f0c 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -35,7 +35,10 @@ find_package(gflags REQUIRED) set(_common_compile_options -Wno-deprecated-declarations -fPIC) # Let files say "include ". -set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include) +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/third-party/json/single_include +) # # The `__srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}. @@ -72,20 +75,11 @@ target_include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/tokenizers/third-party/llama.cpp-unicode/src ) -# find RE2 for tokenizer -set(ABSL_ENABLE_INSTALL ON) -set(ABSL_PROPAGATE_CXX_STD ON) -set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) -add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/tokenizers/third-party/abseil-cpp - ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp -) +# add tokenizers add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/tokenizers/third-party/re2 - ${CMAKE_CURRENT_BINARY_DIR}/re2 + ${EXECUTORCH_ROOT}/extension/llm/tokenizers + ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/tokenizers ) -set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) # build qnn_executor_runner add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner) diff --git a/examples/qualcomm/oss_scripts/albert.py b/examples/qualcomm/oss_scripts/albert.py new file mode 100644 index 00000000000..6af554655f1 --- /dev/null +++ b/examples/qualcomm/oss_scripts/albert.py @@ -0,0 +1,162 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import getpass +import json +import logging +import os +from multiprocessing.connection import Client + +import evaluate +import numpy as np +import torch +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_masked_language_model_dataset, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) +from transformers import AlbertConfig, AutoModelForMaskedLM, AutoTokenizer + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + os.makedirs(args.artifact, exist_ok=True) + data_size = 100 + + model_name = "albert/albert-base-v2" + tokenizer = AutoTokenizer.from_pretrained(model_name, hidden_act="gelu") + + if args.ci: + random_ids = torch.randint(low=0, high=100, size=(1, 100), dtype=torch.int32) + attention_mask = torch.ones((1, 100), dtype=torch.float32) + inputs = [ + ( + random_ids, + attention_mask, + ) + ] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_masked_language_model_dataset( + args.dataset, tokenizer, data_size + ) + + config = AlbertConfig.from_pretrained(model_name) + config.hidden_act = "gelu" + module = AutoModelForMaskedLM.from_pretrained(model_name, config=config).eval() + pte_filename = "albert_qnn_q16" + + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a16w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" + pte_path = f"{args.artifact}/{pte_filename}.pte" + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + ) + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + # accuracy analysis + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + adb.pull(output_path=args.artifact) + # since the original nn.Module could not perform well on this task either + # we only measure the relative accuracy here + goldens, predictions, nominal_predictions = [], [], [] + for i in range(len(inputs)): + indices = [i for i, x in enumerate(targets[i]) if x != -100] + goldens.extend(targets[i][indices].tolist()) + nominal_prediction = module(*inputs[i]) + nominal_predictions.extend( + nominal_prediction.logits.argmax(axis=-1)[0, indices].tolist() + ) + prediction = ( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + .reshape([1, inputs[0][0].shape[1], -1]) + .argmax(axis=-1) + ) + predictions.extend(prediction[0, indices].tolist()) + + metric = evaluate.load("accuracy") + nominal_results = metric.compute( + predictions=nominal_predictions, references=goldens + ) + device_results = metric.compute(predictions=predictions, references=goldens) + result = device_results["accuracy"] / nominal_results["accuracy"] + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"accuracy": result})) + else: + print(f"accuracy: {device_results}") + print(f"accuracy with nn.Module as golden: {result}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./albert", + default="./albert", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation text. " + "e.g. --dataset wikisent2.txt " + "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences" + ), + type=str, + required=False, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/bert.py b/examples/qualcomm/oss_scripts/bert.py new file mode 100644 index 00000000000..96c7826d89c --- /dev/null +++ b/examples/qualcomm/oss_scripts/bert.py @@ -0,0 +1,149 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import getpass +import json +import logging +import os +from multiprocessing.connection import Client + +import evaluate +import numpy as np +import torch + +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_masked_language_model_dataset, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) +from transformers import AutoModelForMaskedLM, AutoTokenizer + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + os.makedirs(args.artifact, exist_ok=True) + data_size = 100 + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + if args.ci: + random_ids = torch.randint(low=0, high=100, size=(1, 100), dtype=torch.int32) + attention_mask = torch.ones((1, 100), dtype=torch.float32) + inputs = [ + ( + random_ids, + attention_mask, + ) + ] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_masked_language_model_dataset( + args.dataset, tokenizer, data_size + ) + module = AutoModelForMaskedLM.from_pretrained( + "google-bert/bert-base-uncased" + ).eval() + pte_filename = "bert_qnn_q16" + + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a8w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" + pte_path = f"{args.artifact}/{pte_filename}.pte" + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + ) + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + # accuracy analysis + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + adb.pull(output_path=args.artifact) + goldens, predictions = [], [] + for i in range(len(inputs)): + indices = [i for i, x in enumerate(targets[i]) if x != -100] + goldens.extend(targets[i][indices].tolist()) + prediction = ( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + .reshape([1, inputs[0][0].shape[1], -1]) + .argmax(axis=-1) + ) + predictions.extend(prediction[0, indices].tolist()) + + metric = evaluate.load("accuracy") + results = metric.compute(predictions=predictions, references=goldens) + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"accuracy": results["accuracy"]})) + else: + print(f"accuracy: {results['accuracy']}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./bert", + default="./bert", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation text. " + "e.g. --dataset wikisent2.txt " + "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences" + ), + type=str, + required=False, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/distilbert.py b/examples/qualcomm/oss_scripts/distilbert.py new file mode 100644 index 00000000000..2863a653200 --- /dev/null +++ b/examples/qualcomm/oss_scripts/distilbert.py @@ -0,0 +1,149 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import getpass +import json +import logging +import os +from multiprocessing.connection import Client + +import evaluate +import numpy as np +import torch + +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_masked_language_model_dataset, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) +from transformers import AutoModelForMaskedLM, AutoTokenizer + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + os.makedirs(args.artifact, exist_ok=True) + data_size = 100 + + tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased") + if args.ci: + random_ids = torch.randint(low=0, high=100, size=(1, 100), dtype=torch.int32) + attention_mask = torch.ones((1, 100), dtype=torch.float32) + inputs = [ + ( + random_ids, + attention_mask, + ) + ] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_masked_language_model_dataset( + args.dataset, tokenizer, data_size + ) + module = AutoModelForMaskedLM.from_pretrained( + "distilbert/distilbert-base-uncased" + ).eval() + pte_filename = "distilbert_qnn_q16" + + # lower to QNN + passes_job = get_capture_program_passes() + build_executorch_binary( + module, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_16a8w, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}" + pte_path = f"{args.artifact}/{pte_filename}.pte" + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + ) + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + # accuracy analysis + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + adb.pull(output_path=args.artifact) + goldens, predictions = [], [] + for i in range(len(inputs)): + indices = [i for i, x in enumerate(targets[i]) if x != -100] + goldens.extend(targets[i][indices].tolist()) + prediction = ( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + .reshape([1, inputs[0][0].shape[1], -1]) + .argmax(axis=-1) + ) + predictions.extend(prediction[0, indices].tolist()) + + metric = evaluate.load("accuracy") + results = metric.compute(predictions=predictions, references=goldens) + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"accuracy": results["accuracy"]})) + else: + print(f"accuracy: {results['accuracy']}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./distilbert", + default="./distilbert", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation text. " + "e.g. --dataset wikisent2.txt " + "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences" + ), + type=str, + required=False, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/eurobert.py b/examples/qualcomm/oss_scripts/eurobert.py new file mode 100644 index 00000000000..97e70428e01 --- /dev/null +++ b/examples/qualcomm/oss_scripts/eurobert.py @@ -0,0 +1,187 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import os +from multiprocessing.connection import Client + +import evaluate +import numpy as np +import torch +import transformers +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) + +from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_eurobert +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_masked_language_model_dataset, + make_output_dir, + make_quantizer, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) +from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer + +TRANSFORMERS_VERSION = "4.48.0" + + +def main(args): + assert ( + transformers.__version__ >= TRANSFORMERS_VERSION + ), f"Please ensure transformers version >= {TRANSFORMERS_VERSION}, current version is {transformers.__version__}" + + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + module_id = "EuroBERT/EuroBERT-210m" + tokenizer = AutoTokenizer.from_pretrained(module_id) + model = AutoModelForMaskedLM.from_pretrained( + module_id, trust_remote_code=True + ).eval() + config = AutoConfig.from_pretrained(module_id, trust_remote_code=True) + + def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module): + for name, child in module.named_children(): + if child._get_name() == "EuroBertRMSNorm": + rms_norm = torch.nn.RMSNorm( + [config.hidden_size], eps=child.variance_epsilon + ) + rms_norm.weight = child.weight + setattr( + module, + name, + rms_norm, + ) + else: + replace_rms_norm_with_native_rms_norm(child) + return module + + replace_rms_norm_with_native_rms_norm(model) + + data_size = 100 + if args.ci: + random_ids = torch.randint(low=0, high=100, size=(1, 100), dtype=torch.int32) + attention_mask = torch.ones((1, 100), dtype=torch.float32) + inputs = [ + ( + random_ids, + attention_mask, + ) + ] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, targets, input_list = get_masked_language_model_dataset( + args.dataset, tokenizer, data_size + ) + + pte_filename = "eurobert_qnn_q16" + + # lower to QNN + passes_job = get_capture_program_passes() + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_16a16w, + ) + quantizer.add_custom_quant_annotations((annotate_eurobert,)) + with torch.no_grad(): + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + dataset=inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + custom_quantizer=quantizer, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) + + if args.compile_only: + return + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + ) + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + # accuracy analysis + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + adb.pull(output_path=args.artifact) + goldens, predictions = [], [] + for i in range(len(inputs)): + indices = [i for i, x in enumerate(targets[i]) if x != -100] + goldens.extend(targets[i][indices].tolist()) + + prediction = ( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + .reshape([1, inputs[0][0].shape[1], -1]) + .argmax(axis=-1) + ) + predictions.extend(prediction[0, indices].tolist()) + metric = evaluate.load("accuracy") + results = metric.compute(predictions=predictions, references=goldens) + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"accuracy": results["accuracy"]})) + else: + print(f"accuracy: {results['accuracy']}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./eurobert", + default="./eurobert", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation text. " + "e.g. --dataset wikisent2.txt " + "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences" + ), + type=str, + required=False, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 246a47fceba..dadf51bf298 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -6,16 +6,16 @@ # model sharding with custom op set(CUSTOM_OP_SRCS_FILE - "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp" + "${EXECUTORCH_SOURCE_DIR}/extension/llm/custom_ops/op_fallback.cpp" ) +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + add_library(custom_ops ${CUSTOM_OP_SRCS_FILE}) target_include_directories(custom_ops PUBLIC "${_common_include_directories}") target_include_directories( custom_ops PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/../../include" ) -target_link_libraries( - custom_ops PUBLIC full_portable_ops_lib -) +target_link_libraries(custom_ops PUBLIC full_portable_ops_lib) target_link_options_shared_lib(custom_ops) # preprocess qnn runner src files for llama @@ -44,17 +44,15 @@ list( ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h ) -list( - APPEND - _llama_runner__srcs - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/src/tiktoken.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp -) +list(APPEND _llama_runner__srcs) # build qnn llama runner add_executable(qnn_llama_runner ${_llama_runner__srcs}) target_include_directories( - qnn_llama_runner PUBLIC ${_common_include_directories} ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include + qnn_llama_runner + PUBLIC + ${_common_include_directories} + ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include ) target_link_options_shared_lib(quantized_ops_lib) @@ -68,14 +66,12 @@ target_link_libraries( extension_module extension_tensor gflags - re2::re2 custom_ops quantized_ops_lib quantized_kernels + tokenizers ) -target_compile_options( - qnn_llama_runner PUBLIC ${_common_compile_options} -) +target_compile_options(qnn_llama_runner PUBLIC ${_common_compile_options}) set_target_properties( qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" ) diff --git a/examples/qualcomm/oss_scripts/llama/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS index aee00c44c76..9c5dd1ceaf9 100644 --- a/examples/qualcomm/oss_scripts/llama/TARGETS +++ b/examples/qualcomm/oss_scripts/llama/TARGETS @@ -49,6 +49,9 @@ python_binary( name = "eval_llama_qnn", srcs = ["eval_llama_qnn.py"], main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main", + preload_deps = [ + "//executorch/extension/llm/custom_ops:model_sharding_py", + ], deps = [ ":llama_lib", "//executorch/examples/models/llama:eval_library", diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index dd8fc704032..1105ac0ef82 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -5,49 +5,77 @@ # LICENSE file in the root directory of this source tree. import argparse -import logging import copy import json -import torch -from lm_eval.evaluator import simple_evaluate -from typing import List, Optional, Tuple +import logging +import sys + +from typing import List, Tuple import torch import torch.nn as nn +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_linear_16a8w_in_affine_layer, + annotate_matmul_16a8w, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d from executorch.examples.models.llama.eval_llama_lib import ( build_args_parser, - GraphModuleEvalWrapper + GraphModuleEvalWrapper, ) -from pytorch_tokenizers import get_tokenizer +from executorch.examples.models.llama.source_transformation.quantize import ( + get_quant_embedding_transform, +) + +from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( - LlamaModel, - ModelArgs, + LlamaModel, + ModelArgs, ) +from executorch.examples.qualcomm.utils import make_quantizer + +from lm_eval.evaluator import simple_evaluate + +from pytorch_tokenizers import get_tokenizer + +from torchao.quantization.pt2e import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + +sys.setrecursionlimit(4096) +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + class WrappedLlamaModel(nn.Module): - def __init__(self, model, use_kv_cache=False, max_seq_len=512, device='cuda'): + def __init__( + self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda" + ): super(WrappedLlamaModel, self).__init__() self.model = model self.max_seq_len = max_seq_len self.use_kv_cache = use_kv_cache self.device = device + self.atten_mask = atten_mask - def forward(self, + def forward( + self, tokens: torch.Tensor, - input_pos: Optional[torch.Tensor] = None, *args, ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: # Pad input if necessary, since LlamaModel requires static shape if tokens.shape[1] != self.max_seq_len: - tokens = torch.nn.functional.pad(tokens, (self.max_seq_len - tokens.shape[1],0)) - atten_mask = self.model.get_example_inputs(self.use_kv_cache)[1].to(device=self.device).to(dtype=torch.bfloat16) - return self.model.forward(tokens, atten_mask, input_pos, *args) - + tokens = torch.nn.functional.pad( + tokens, (0, self.max_seq_len - tokens.shape[1]) + ) + return self.model.forward(tokens, self.atten_mask) def gen_eval_wrapper(model_name, args): @@ -66,7 +94,13 @@ def gen_eval_wrapper(model_name, args): ) config = prefill_config use_i64_token = args.embedding_quantize is not None - model = LlamaModel(config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, use_i64_token=use_i64_token) + model = LlamaModel( + config, + ar_len=args.prefill_ar_len, + output_new_cache_only=True, + output_cache=False, + use_i64_token=use_i64_token, + ) state_dict = torch.load( args.checkpoint, weights_only=True, map_location=args.device, mmap=True ) @@ -109,12 +143,69 @@ def permute(w, heads): layer.feed_forward.prepare_feedfoward_conv() model.to(dtype=torch.bfloat16) - model.to(args.device) + model.to(device=args.device) + + tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) + tokens = tokens.to(device=args.device) + atten_mask = atten_mask.to(device=args.device) + atten_mask = atten_mask.to(dtype=torch.bfloat16) + inputs = (tokens, atten_mask) + + if args.embedding_quantize: + model = get_quant_embedding_transform( + embedding_quantize=args.embedding_quantize + )(model) + + model = convert_linear_to_conv2d(model) + + if args.ptq: + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + + custom_annotations = (annotate_matmul_16a8w,) + if args.llama_model == "stories110m": + custom_annotations = custom_annotations + ( + annotate_linear_16a8w_in_affine_layer, + ) + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations(custom_annotations) + + model.has_quant_io = True + + with torch.no_grad(): + model = torch.export.export(model, inputs, strict=True).module() + if quant_dtype == QuantDtype.use_16a4w_block: + conv_nodes = [n for n in model.graph.nodes if "conv" in n.name] + block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} + quantizer.set_block_size_map(block_size_map) + + model = prepare_pt2e(model, quantizer) + + logging.info("Quantizing the model...") + + calibrate( + inputs, + "Once upon a time", + model, + tokenizer=tokenizer, + ar_len=args.prefill_ar_len, + max_seq_len=args.max_seq_len, + kv_updater=None, + use_i64_token=use_i64_token, + ) + + model = convert_pt2e(model) - wrapped_model = WrappedLlamaModel(model, args.use_kv_cache, args.max_seq_length, args.device) + model = WrappedLlamaModel( + model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device + ) return GraphModuleEvalWrapper( - model=wrapped_model, + model=model, tokenizer=tokenizer, max_seq_length=args.calibration_seq_length, use_kv_cache=args.use_kv_cache, @@ -123,7 +214,6 @@ def permute(w, heads): ) - def eval_llama( model_name: str, args: argparse.Namespace, @@ -156,6 +246,7 @@ def main() -> None: modelname = "llama2" parser = build_args_parser() args = parser.parse_args() + args.llama_model = "llama3_2" # Overrides this arg, because evaluation requires full logits. args.generate_full_logits = True @@ -166,7 +257,14 @@ def main() -> None: args.use_kv_cache = False args.prefill_ar_len = args.max_seq_length - args.device = 'cuda' if torch.cuda.is_available() else 'cpu' + # To do fewer samples for faster evaluation + args.limit = 0.1 + # args.samples = {'wikitext': list(range(1))} + + args.device = "cuda" if torch.cuda.is_available() else "cpu" + torch.set_default_device(args.device) + + args.ptq = "8a8w" eval_llama(modelname, args) diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 33482090b28..99f346eccbc 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -398,7 +398,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): self.llama_graph_module, self.inputs, strict=True ).module() - if QuantDtype == QuantDtype.use_16a4w_block: + if quant_dtype == QuantDtype.use_16a4w_block: conv_nodes = [ n for n in fx_graph_module.graph.nodes if "conv" in n.name ] diff --git a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt index 4e44a1599b1..2a13bbe861c 100644 --- a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt @@ -6,6 +6,8 @@ # preprocess qaihub runner src files for llama2,3 set(_qaihub_llama_runner__srcs ${_llama_runner__srcs}) +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + list(TRANSFORM _qaihub_llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/") list(FILTER _qaihub_llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*") list( @@ -26,13 +28,11 @@ list(PREPEND _qaihub_llama2_7b_runner__srcs # build qaihub llama2 7b runner add_executable(qaihub_llama2_7b_runner ${_qaihub_llama2_7b_runner__srcs}) + target_include_directories( - qaihub_llama2_7b_runner PUBLIC - ${_common_include_directories} - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/json/single_include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode/src + qaihub_llama2_7b_runner + PUBLIC ${_common_include_directories} + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include ) target_link_libraries( qaihub_llama2_7b_runner @@ -43,7 +43,7 @@ target_link_libraries( extension_module extension_tensor gflags - re2::re2 + tokenizers ) target_compile_options( qaihub_llama2_7b_runner PUBLIC ${_common_compile_options} @@ -62,25 +62,13 @@ list(PREPEND _qaihub_llama3_8b_runner__srcs # Adding a compile option to differentiate llama2 with llama3 logic list(APPEND _common_compile_options -DQAIHUB_LLAMA3_RUNNER) -list( - APPEND _qaihub_llama3_8b_runner__srcs - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/src/tiktoken.cpp -) -list( - APPEND - _qaihub_llama3_8b_runner__srcs - ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp -) - # build qaihub llama3 8b runner add_executable(qaihub_llama3_8b_runner ${_qaihub_llama3_8b_runner__srcs}) target_include_directories( - qaihub_llama3_8b_runner PUBLIC - ${_common_include_directories} + qaihub_llama3_8b_runner + PUBLIC + ${_common_include_directories} ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/json/single_include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizers/third-party/llama.cpp-unicode/src ) target_link_libraries( @@ -92,7 +80,7 @@ target_link_libraries( extension_module extension_tensor gflags - re2::re2 + tokenizers ) target_compile_options( qaihub_llama3_8b_runner PUBLIC ${_common_compile_options} diff --git a/examples/selective_build/CMakeLists.txt b/examples/selective_build/CMakeLists.txt index 5a824f51ff9..39b212b16fb 100644 --- a/examples/selective_build/CMakeLists.txt +++ b/examples/selective_build/CMakeLists.txt @@ -33,7 +33,7 @@ if(NOT CMAKE_CXX_STANDARD) # Can't set to 11 due to executor_runner.cpp make_unique endif() -set(_common_compile_options -Wno-deprecated-declarations -fPIC) +set(_common_compile_options -Wno-deprecated-declarations -fPIC -ffunction-sections -fdata-sections) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) @@ -123,13 +123,25 @@ gen_selected_ops( ) generate_bindings_for_kernels( - LIB_NAME "select_build_lib" FUNCTIONS_YAML - ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml CUSTOM_OPS_YAML + LIB_NAME + "select_build_lib" + FUNCTIONS_YAML + ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml + CUSTOM_OPS_YAML "${_custom_ops_yaml}" + DTYPE_SELECTIVE_BUILD + "${EXECUTORCH_DTYPE_SELECTIVE_BUILD}" ) gen_operators_lib( - LIB_NAME "select_build_lib" KERNEL_LIBS ${_kernel_lib} DEPS executorch_core + LIB_NAME + "select_build_lib" + KERNEL_LIBS + ${_kernel_lib} + DEPS + executorch_core + DTYPE_SELECTIVE_BUILD + "${EXECUTORCH_DTYPE_SELECTIVE_BUILD}" ) list(TRANSFORM _executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/") diff --git a/examples/selective_build/test_selective_build.sh b/examples/selective_build/test_selective_build.sh index 32097b28170..eb9038b7cdd 100644 --- a/examples/selective_build/test_selective_build.sh +++ b/examples/selective_build/test_selective_build.sh @@ -162,13 +162,17 @@ test_cmake_select_ops_in_yaml() { } test_cmake_select_ops_in_model() { - echo "Exporting MobilenetV2" - ${PYTHON_EXECUTABLE} -m examples.portable.scripts.export --model_name="mv2" + local model_name="add_mul" + local model_export_name="${model_name}.pte" + echo "Exporting ${model_name}" + ${PYTHON_EXECUTABLE} -m examples.portable.scripts.export --model_name="${model_name}" local example_dir=examples/selective_build local build_dir=cmake-out/${example_dir} rm -rf ${build_dir} - retry cmake -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE \ - -DEXECUTORCH_SELECT_OPS_FROM_MODEL="./mv2.pte" \ + retry cmake -DCMAKE_BUILD_TYPE="$CMAKE_BUILD_TYPE" \ + -DEXECUTORCH_SELECT_OPS_FROM_MODEL="./${model_export_name}" \ + -DEXECUTORCH_DTYPE_SELECTIVE_BUILD=ON \ + -DEXECUTORCH_OPTIMIZE_SIZE=ON \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ -B${build_dir} \ @@ -178,10 +182,10 @@ test_cmake_select_ops_in_model() { cmake --build ${build_dir} -j9 --config $CMAKE_BUILD_TYPE echo 'Running selective build test' - ${build_dir}/selective_build_test --model_path="./mv2.pte" + ${build_dir}/selective_build_test --model_path="./${model_export_name}" - echo "Removing mv2.pte" - rm "./mv2.pte" + echo "Removing ${model_export_name}" + rm "./${model_export_name}" } if [[ -z $BUCK ]]; diff --git a/examples/xnnpack/README.md b/examples/xnnpack/README.md index 6fe1f0488b2..ad09bb90d37 100644 --- a/examples/xnnpack/README.md +++ b/examples/xnnpack/README.md @@ -38,6 +38,7 @@ mkdir cmake-out cmake \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/extension/apple/CMakeLists.txt b/extension/apple/CMakeLists.txt index d55fa381375..0e978073aa2 100644 --- a/extension/apple/CMakeLists.txt +++ b/extension/apple/CMakeLists.txt @@ -70,13 +70,14 @@ file(WRITE ${MODULE_MAP_FILE} ") set(SWIFT_CLANG_INTEROP_FLAGS "-Xcc -fmodule-map-file=${MODULE_MAP_FILE} -I ${MODULE_MAP_DIR}") +set(SWIFT_REMAP_FLAGS "-debug-prefix-map ${PROJECT_SOURCE_DIR}=/executorch") set_target_properties(extension_apple PROPERTIES Swift_MODULE_NAME "ExecuTorch" - Swift_FLAGS "${SWIFT_CLANG_INTEROP_FLAGS}" + Swift_FLAGS "${SWIFT_CLANG_INTEROP_FLAGS} ${SWIFT_REMAP_FLAGS}" XCODE_ATTRIBUTE_SWIFT_MODULE_NAME "ExecuTorch" XCODE_ATTRIBUTE_BUILD_LIBRARY_FOR_DISTRIBUTION "YES" - XCODE_ATTRIBUTE_OTHER_SWIFT_FLAGS "${SWIFT_CLANG_INTEROP_FLAGS}" + XCODE_ATTRIBUTE_OTHER_SWIFT_FLAGS "${SWIFT_CLANG_INTEROP_FLAGS} ${SWIFT_REMAP_FLAGS}" ) add_custom_command( diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift index e097a9253de..01eb24d15be 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift @@ -60,7 +60,7 @@ public extension Module { /// - Returns: An array of `Value` objects representing the outputs. /// - Throws: An error if method execution fails. func execute(_ method: String, _ inputs: [ValueConvertible]) throws -> [Value] { - try __executeMethod(method, withInputs: inputs.map { $0.objcValue() } ) + try __executeMethod(method, withInputs: inputs.map { $0.asValue() } ) } /// Executes a specific method with a single input value. @@ -72,7 +72,7 @@ public extension Module { /// - Returns: An array of `Value` objects representing the outputs. /// - Throws: An error if method execution fails. func execute(_ method: String, _ input: ValueConvertible) throws -> [Value] { - try __executeMethod(method, withInputs: [input.objcValue()]) + try __executeMethod(method, withInputs: [input.asValue()]) } /// Executes the "forward" method with the provided input values. @@ -82,7 +82,7 @@ public extension Module { /// - Returns: An array of `Value` objects representing the outputs. /// - Throws: An error if method execution fails. func forward(_ inputs: [ValueConvertible]) throws -> [Value] { - try __executeMethod("forward", withInputs: inputs.map { $0.objcValue() }) + try __executeMethod("forward", withInputs: inputs.map { $0.asValue() }) } /// Executes the "forward" method with a single input value. @@ -92,6 +92,6 @@ public extension Module { /// - Returns: An array of `Value` objects representing the outputs. /// - Throws: An error if method execution fails. func forward(_ input: ValueConvertible) throws -> [Value] { - try __executeMethod("forward", withInputs: [input.objcValue()]) + try __executeMethod("forward", withInputs: [input.asValue()]) } } diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift index 02270e0b149..a3873794f9d 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -8,6 +8,15 @@ @_exported import ExecuTorch +/// Computes the total number of elements in a tensor based on its shape. +/// +/// - Parameter shape: An array of integers, where each element represents a dimension size. +/// - Returns: An integer equal to the product of the sizes of all dimensions. +@available(*, deprecated, message: "This API is experimental.") +public func elementCount(ofShape shape: [Int]) -> Int { + __ExecuTorchElementCountOfShape(shape.map(NSNumber.init)) +} + /// A protocol that types conform to in order to be used as tensor element types. /// Provides the mapping from the Swift type to the underlying `DataType`. @available(*, deprecated, message: "This API is experimental.") @@ -122,36 +131,482 @@ extension UInt: Scalar { public func asNSNumber() -> NSNumber { NSNumber(value: self) } } -/// A tensor class for ExecuTorch operations. +/// A type-erasing tensor class for ExecuTorch operations. +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// The shape of the tensor. + var shape: [Int] { __shape.map(\.intValue) } + + /// The strides of the tensor. + var strides: [Int] { __strides.map(\.intValue) } + + /// The order of dimensions in the tensor. + var dimensionOrder: [Int] { __dimensionOrder.map(\.intValue) } + + /// The total number of elements in the tensor. + var count: Int { __count } + + /// Initializes a tensor without copying the provided data. + /// + /// - Parameters: + /// - pointer: A pointer to the data buffer. + /// - shape: An array of integers representing the tensor's shape. + /// - strides: An array of integers representing the tensor's strides. + /// - dimensionOrder: An array of integers indicating the order of dimensions. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A `ShapeDynamism` value indicating whether the shape is static or dynamic. + convenience init( + bytesNoCopy pointer: UnsafeMutableRawPointer, + shape: [Int], + strides: [Int] = [], + dimensionOrder: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) { + self.init( + __bytesNoCopy: pointer, + shape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dimensionOrder: dimensionOrder.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Initializes a tensor by copying bytes from the provided pointer. + /// + /// - Parameters: + /// - pointer: A pointer to the source data buffer. + /// - shape: An array of integers representing the tensor's shape. + /// - strides: An array of integers representing the tensor's strides. + /// - dimensionOrder: An array of integers indicating the order of dimensions. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A `ShapeDynamism` value indicating the shape dynamism. + convenience init( + bytes pointer: UnsafeRawPointer, + shape: [Int], + strides: [Int] = [], + dimensionOrder: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) { + self.init( + __bytes: pointer, + shape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dimensionOrder: dimensionOrder.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Initializes a tensor using a `Data` object. The tensor holds a reference + /// to the `Data` object to ensure its buffer remains alive. The data is not copied. + /// + /// - Parameters: + /// - data: A `Data` object containing the tensor data. + /// - shape: An array of integers representing the tensor's shape. + /// - strides: An array of integers representing the tensor's strides. + /// - dimensionOrder: An array of integers indicating the order of dimensions. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A `ShapeDynamism` value indicating the shape dynamism. + convenience init( + data: Data, + shape: [Int], + strides: [Int] = [], + dimensionOrder: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) { + self.init( + __data: data, + shape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dimensionOrder: dimensionOrder.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Resizes the tensor to a new shape. + /// + /// - Parameter shape: An array of `Int` representing the desired new shape. + /// - Throws: An error if the resize operation fails. + func resize(to shape: [Int]) throws { + try __resize(toShape: shape.map(NSNumber.init)) + } + + // MARK: Equatable + + /// Determines whether the current tensor is equal to another tensor. + /// + /// - Parameters: + /// - lhs: The left-hand side tensor. + /// - rhs: The right-hand side tensor. + /// - Returns: `true` if the tensors have the same type, shape, strides, and data; otherwise, `false`. + static func == (lhs: AnyTensor, rhs: AnyTensor) -> Bool { + lhs.__isEqual(to: rhs) + } + + /// Attempts to convert this type-erased `AnyTensor` into a strongly-typed `Tensor`. + /// + /// - Returns: An `AnyTensor` if `self.dataType == T.dataType`, + /// otherwise `nil` when the runtime dtype doesn’t match. + func asTensor() -> Tensor? { + guard dataType == T.dataType else { return nil } + return Tensor(self) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates an empty tensor with the specified properties. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new, empty `AnyTensor` instance. + static func empty( + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __empty( + withShape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates an empty tensor with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new, empty `AnyTensor` instance. + static func empty( + like tensor: AnyTensor, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __emptyTensorLike( + tensor, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor filled with the specified scalar value. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - scalar: The value to fill the tensor with. + /// - strides: An array of integers representing the desired strides. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with the scalar value. + static func full( + shape: [Int], + scalar: T, + strides: [Int] = [], + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __fullTensor( + withShape: shape.map(NSNumber.init), + scalar: scalar.asNSNumber(), + strides: strides.map(NSNumber.init), + dataType: T.dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor filled with a scalar value, with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - scalar: The value to fill the tensor with. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with the scalar value. + static func full( + like tensor: AnyTensor, + scalar: T, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __fullTensorLike( + tensor, + scalar: scalar.asNSNumber(), + dataType: T.dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor filled with ones. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with ones. + static func ones( + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __onesTensor( + withShape: shape.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor of ones with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with ones. + static func ones( + like tensor: AnyTensor, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __onesTensorLike( + tensor, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor filled with zeros. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with zeros. + static func zeros( + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __zerosTensor( + withShape: shape.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor of zeros with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with zeros. + static func zeros( + like tensor: AnyTensor, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __zerosTensorLike( + tensor, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor with random values uniformly distributed in `[0, 1)`. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with random values. + static func rand( + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomTensor( + withShape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor with random values with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with random values. + static func rand( + like tensor: AnyTensor, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomTensorLike( + tensor, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor with random values from a normal distribution with mean `0` and variance `1`. + /// + /// - Parameters: + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with values from a normal distribution. + static func randn( + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomNormalTensor( + withShape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor with random normal values with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with values from a normal distribution. + static func randn( + like tensor: AnyTensor, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomNormalTensorLike( + tensor, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension AnyTensor { + /// Creates a tensor with random integers from `low` (inclusive) to `high` (exclusive). + /// + /// - Parameters: + /// - low: The inclusive lower bound of the random integer range. + /// - high: The exclusive upper bound of the random integer range. + /// - shape: An array of integers representing the desired shape. + /// - strides: An array of integers representing the desired strides. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with random integer values. + static func randint( + low: Int, + high: Int, + shape: [Int], + strides: [Int] = [], + dataType: DataType, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomIntegerTensor( + withLow: low, + high: high, + shape: shape.map(NSNumber.init), + strides: strides.map(NSNumber.init), + dataType: dataType, + shapeDynamism: shapeDynamism + ) + } + + /// Creates a tensor with random integers with the same properties as a given tensor. + /// + /// - Parameters: + /// - like: An existing `AnyTensor` instance whose shape and strides are used. + /// - low: The inclusive lower bound of the random integer range. + /// - high: The exclusive upper bound of the random integer range. + /// - dataType: A `DataType` value specifying the element type. + /// - shapeDynamism: A value specifying whether the shape is static or dynamic. + /// - Returns: A new `AnyTensor` instance filled with random integer values. + static func randint( + like tensor: AnyTensor, + low: Int, + high: Int, + dataType: DataType = .undefined, + shapeDynamism: ShapeDynamism = .dynamicBound + ) -> AnyTensor { + __randomIntegerTensorLike( + tensor, + low: low, + high: high, + dataType: dataType == .undefined ? tensor.dataType : dataType, + shapeDynamism: shapeDynamism + ) + } +} + +/// A generic tensor class for ExecuTorch operations. /// -/// This class encapsulates a native `ExecuTorchTensor` instance and provides a variety of +/// This class encapsulates a type-erasing `AnyTensor` instance and provides a variety of /// initializers and utility methods to work with tensor data. @available(*, deprecated, message: "This API is experimental.") public class Tensor: Equatable { /// The data type of the tensor's elements. - public var dataType: DataType { objcTensor.dataType } + public var dataType: DataType { anyTensor.dataType } /// The shape of the tensor. - public var shape: [Int] { objcTensor.shape.map(\.intValue) } + public var shape: [Int] { anyTensor.shape } /// The strides of the tensor. - public var strides: [Int] { objcTensor.strides.map(\.intValue) } + public var strides: [Int] { anyTensor.strides } /// The order of dimensions in the tensor. - public var dimensionOrder: [Int] { objcTensor.dimensionOrder.map(\.intValue) } + public var dimensionOrder: [Int] { anyTensor.dimensionOrder } /// The dynamism of the tensor's shape. - public var shapeDynamism: ShapeDynamism { objcTensor.shapeDynamism } + public var shapeDynamism: ShapeDynamism { anyTensor.shapeDynamism } /// The total number of elements in the tensor. - public var count: Int { objcTensor.count } + public var count: Int { anyTensor.count } - /// Initializes a tensor with an `ExecuTorchTensor` instance. + /// Initializes a tensor with an `AnyTensor` instance. /// - /// - Parameter tensor: An `ExecuTorchTensor` instance. - public init(_ tensor: __ExecuTorchTensor) { + /// - Parameter tensor: An `AnyTensor` instance. + public init(_ tensor: AnyTensor) { precondition(tensor.dataType == T.dataType) - objcTensor = tensor + anyTensor = tensor } /// Creates a new tensor that shares the underlying data storage with the @@ -159,7 +614,7 @@ public class Tensor: Equatable { /// /// - Parameter tensor: The tensor to create a view of. public convenience init(_ tensor: Tensor) { - self.init(__ExecuTorchTensor(tensor.objcTensor)) + self.init(AnyTensor(tensor.anyTensor)) } /// Initializes a tensor without copying the provided data. @@ -177,11 +632,11 @@ public class Tensor: Equatable { dimensionOrder: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) { - self.init(__ExecuTorchTensor( + self.init(AnyTensor( bytesNoCopy: pointer, - shape: shape.map(NSNumber.init), - strides: strides.map(NSNumber.init), - dimensionOrder: dimensionOrder.map(NSNumber.init), + shape: shape, + strides: strides, + dimensionOrder: dimensionOrder, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -202,11 +657,11 @@ public class Tensor: Equatable { dimensionOrder: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) { - self.init(__ExecuTorchTensor( + self.init(AnyTensor( bytes: pointer, - shape: shape.map(NSNumber.init), - strides: strides.map(NSNumber.init), - dimensionOrder: dimensionOrder.map(NSNumber.init), + shape: shape, + strides: strides, + dimensionOrder: dimensionOrder, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -228,11 +683,11 @@ public class Tensor: Equatable { dimensionOrder: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) { - self.init(__ExecuTorchTensor( + self.init(AnyTensor( data: data, - shape: shape.map(NSNumber.init), - strides: strides.map(NSNumber.init), - dimensionOrder: dimensionOrder.map(NSNumber.init), + shape: shape, + strides: strides, + dimensionOrder: dimensionOrder, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -253,14 +708,14 @@ public class Tensor: Equatable { dimensionOrder: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) { - let nsShape = (shape.isEmpty ? [scalars.count] : shape).map(NSNumber.init) - precondition(scalars.count == elementCount(ofShape: nsShape)) - self.init(scalars.withUnsafeBufferPointer { buffer in - __ExecuTorchTensor( - bytes: buffer.baseAddress!, - shape: nsShape, - strides: strides.map(NSNumber.init), - dimensionOrder: dimensionOrder.map(NSNumber.init), + let newShape = shape.isEmpty ? [scalars.count] : shape + precondition(scalars.count == elementCount(ofShape: newShape)) + self.init(scalars.withUnsafeBufferPointer { + AnyTensor( + bytes: $0.baseAddress!, + shape: newShape, + strides: strides, + dimensionOrder: dimensionOrder, dataType: T.dataType, shapeDynamism: shapeDynamism ) @@ -271,14 +726,14 @@ public class Tensor: Equatable { /// /// - Parameter scalar: A scalar value. public convenience init(_ scalar: T) { - self.init(__ExecuTorchTensor(scalar.asNSNumber(), dataType: T.dataType)) + self.init(AnyTensor(__scalar: scalar.asNSNumber(), dataType: T.dataType)) } /// Returns a copy of the tensor. /// /// - Returns: A new `Tensor` instance that is a duplicate of the current tensor. public func copy() -> Tensor { - Tensor(objcTensor.copy()) + Tensor(anyTensor.copy()) } /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements. @@ -288,7 +743,7 @@ public class Tensor: Equatable { /// - Throws: Any error thrown by `body`. public func withUnsafeBytes(_ body: (UnsafeBufferPointer) throws -> R) throws -> R { var result: Result? - objcTensor.bytes { pointer, count, _ in + anyTensor.bytes { pointer, count, _ in result = Result { try body( UnsafeBufferPointer( start: pointer.assumingMemoryBound(to: T.self), @@ -306,7 +761,7 @@ public class Tensor: Equatable { /// - Throws: Any error thrown by `body`. public func withUnsafeMutableBytes(_ body: (UnsafeMutableBufferPointer) throws -> R) throws -> R { var result: Result? - objcTensor.mutableBytes { pointer, count, _ in + anyTensor.mutableBytes { pointer, count, _ in result = Result { try body( UnsafeMutableBufferPointer( start: pointer.assumingMemoryBound(to: T.self), @@ -322,7 +777,7 @@ public class Tensor: Equatable { /// - Parameter shape: An array of `Int` representing the desired new shape. /// - Throws: An error if the resize operation fails. public func resize(to shape: [Int]) throws { - try objcTensor.resize(to: shape.map(NSNumber.init)) + try anyTensor.resize(to: shape) } // MARK: Equatable @@ -334,12 +789,12 @@ public class Tensor: Equatable { /// - rhs: The right-hand side tensor. /// - Returns: `true` if the tensors have the same type, shape, strides, and data; otherwise, `false`. public static func == (lhs: Tensor, rhs: Tensor) -> Bool { - lhs.objcTensor.isEqual(to: rhs.objcTensor) + lhs.anyTensor == rhs.anyTensor } // MARK: Internal - let objcTensor: __ExecuTorchTensor + let anyTensor: AnyTensor } @available(*, deprecated, message: "This API is experimental.") @@ -367,9 +822,9 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.empty( - shape: shape.map(NSNumber.init), - strides: strides.map(NSNumber.init), + Tensor(AnyTensor.empty( + shape: shape, + strides: strides, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -382,12 +837,11 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new, empty `Tensor` instance. static func empty( - like: Tensor, + like tensor: Tensor, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.empty( - like: like.objcTensor, - dataType: T.dataType, + Tensor(AnyTensor.empty( + like: tensor.anyTensor, shapeDynamism: shapeDynamism )) } @@ -409,11 +863,10 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.full( - shape: shape.map(NSNumber.init), - scalar: scalar.asNSNumber(), - strides: strides.map(NSNumber.init), - dataType: T.dataType, + Tensor(AnyTensor.full( + shape: shape, + scalar: scalar, + strides: strides, shapeDynamism: shapeDynamism )) } @@ -426,14 +879,13 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with the scalar value. static func full( - like: Tensor, + like tensor: Tensor, scalar: T, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.full( - like: like.objcTensor, - scalar: scalar.asNSNumber(), - dataType: T.dataType, + Tensor(AnyTensor.full( + like: tensor.anyTensor, + scalar: scalar, shapeDynamism: shapeDynamism )) } @@ -453,8 +905,8 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.ones( - shape: shape.map(NSNumber.init), + Tensor(AnyTensor.ones( + shape: shape, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -467,12 +919,11 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with ones. static func ones( - like: Tensor, + like tensor: Tensor, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.ones( - like: like.objcTensor, - dataType: T.dataType, + Tensor(AnyTensor.ones( + like: tensor.anyTensor, shapeDynamism: shapeDynamism )) } @@ -492,8 +943,8 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.zeros( - shape: shape.map(NSNumber.init), + Tensor(AnyTensor.zeros( + shape: shape, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -506,12 +957,11 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with zeros. static func zeros( - like: Tensor, + like tensor: Tensor, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.zeros( - like: like.objcTensor, - dataType: T.dataType, + Tensor(AnyTensor.zeros( + like: tensor.anyTensor, shapeDynamism: shapeDynamism )) } @@ -531,8 +981,9 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.rand( - shape: shape.map(NSNumber.init), + Tensor(AnyTensor.rand( + shape: shape, + strides: strides, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -545,12 +996,11 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with random values. static func rand( - like: Tensor, + like tensor: Tensor, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.rand( - like: like.objcTensor, - dataType: T.dataType, + Tensor(AnyTensor.rand( + like: tensor.anyTensor, shapeDynamism: shapeDynamism )) } @@ -570,8 +1020,9 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.randn( - shape: shape.map(NSNumber.init), + Tensor(AnyTensor.randn( + shape: shape, + strides: strides, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -584,12 +1035,11 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with values from a normal distribution. static func randn( - like: Tensor, + like tensor: Tensor, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.randn( - like: like.objcTensor, - dataType: T.dataType, + Tensor(AnyTensor.randn( + like: tensor.anyTensor, shapeDynamism: shapeDynamism )) } @@ -613,10 +1063,11 @@ public extension Tensor { strides: [Int] = [], shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.randint( + Tensor(AnyTensor.randint( low: low, high: high, - shape: shape.map(NSNumber.init), + shape: shape, + strides: strides, dataType: T.dataType, shapeDynamism: shapeDynamism )) @@ -631,17 +1082,23 @@ public extension Tensor { /// - shapeDynamism: A value specifying whether the shape is static or dynamic. /// - Returns: A new `Tensor` instance filled with random integer values. static func randint( - like: Tensor, + like tensor: Tensor, low: Int, high: Int, shapeDynamism: ShapeDynamism = .dynamicBound ) -> Tensor { - Tensor(__ExecuTorchTensor.randint( - like: like.objcTensor, + Tensor(AnyTensor.randint( + like: tensor.anyTensor, low: low, high: high, - dataType: T.dataType, shapeDynamism: shapeDynamism )) } } + +@available(*, deprecated, message: "This API is experimental.") +extension Tensor: CustomStringConvertible { + public var description: String { + self.anyTensor.description + } +} diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift index 647509fd0c1..148b8f03cf0 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift @@ -13,7 +13,7 @@ @available(*, deprecated, message: "This API is experimental.") public protocol ValueConvertible { /// Converts the instance into a `Value`. - func objcValue() -> Value + func asValue() -> Value } @available(*, deprecated, message: "This API is experimental.") @@ -22,7 +22,14 @@ public extension Value { /// /// - Parameter tensor: The `Tensor` to wrap. convenience init(_ tensor: Tensor) { - self.init(__tensor: tensor.objcTensor) + self.init(tensor.anyTensor) + } + + /// Attempts to return the underlying type-erased `AnyTensor` if the `Value` contains one. + /// + /// - Returns: An `AnyTensor`, or `nil` if the `Value` is not a tensor. + var anyTensor: AnyTensor? { + __tensorValue } /// Attempts to return the underlying `Tensor` if the `Value` contains one. @@ -30,8 +37,7 @@ public extension Value { /// - Returns: A `Tensor` of the specified scalar type, or `nil` if the /// `Value` is not a tensor or the data type does not match. func tensor() -> Tensor? { - guard isTensor, let tensor = __tensorValue, tensor.dataType == T.dataType else { return nil } - return Tensor(tensor) + anyTensor?.asTensor() } } @@ -40,101 +46,107 @@ public extension Value { @available(*, deprecated, message: "This API is experimental.") extension Value: ValueConvertible { /// Returns the `Value` itself. - public func objcValue() -> Value { self } + public func asValue() -> Value { self } +} + +@available(*, deprecated, message: "This API is experimental.") +extension AnyTensor: ValueConvertible { + /// Converts the `Tensor` into a `Value`. + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension Tensor: ValueConvertible { /// Converts the `Tensor` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension String: ValueConvertible { /// Converts the `String` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension NSNumber: ValueConvertible { /// Converts the `NSNumber` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension UInt8: ValueConvertible { /// Converts the `UInt8` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: Int(self))) } + public func asValue() -> Value { Value(NSNumber(value: Int(self))) } } @available(*, deprecated, message: "This API is experimental.") extension Int8: ValueConvertible { /// Converts the `Int8` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: Int(self))) } + public func asValue() -> Value { Value(NSNumber(value: Int(self))) } } @available(*, deprecated, message: "This API is experimental.") extension Int16: ValueConvertible { /// Converts the `Int16` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension Int32: ValueConvertible { /// Converts the `Int32` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension Int64: ValueConvertible { /// Converts the `Int64` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension Int: ValueConvertible { /// Converts the `Int` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension Float: ValueConvertible { /// Converts the `Float` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension Double: ValueConvertible { /// Converts the `Double` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension Bool: ValueConvertible { /// Converts the `Bool` into a `Value`. - public func objcValue() -> Value { Value(self) } + public func asValue() -> Value { Value(self) } } @available(*, deprecated, message: "This API is experimental.") extension UInt16: ValueConvertible { /// Converts the `UInt16` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension UInt32: ValueConvertible { /// Converts the `UInt32` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension UInt64: ValueConvertible { /// Converts the `UInt64` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } @available(*, deprecated, message: "This API is experimental.") extension UInt: ValueConvertible { /// Converts the `UInt` into a `Value`. - public func objcValue() -> Value { Value(NSNumber(value: self)) } + public func asValue() -> Value { Value(NSNumber(value: self)) } } diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 1be47b42bbd..e4a6ce49cd3 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -81,7 +81,7 @@ NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) FOUNDATION_EXPORT __attribute__((deprecated("This API is experimental."))) NSInteger ExecuTorchElementCountOfShape(NSArray *shape) - NS_SWIFT_NAME(elementCount(ofShape:)); + NS_REFINED_FOR_SWIFT; /** * A tensor class for ExecuTorch operations. @@ -89,7 +89,7 @@ NSInteger ExecuTorchElementCountOfShape(NSArray *shape) * This class encapsulates a native TensorPtr instance and provides a variety of * initializers and utility methods to work with tensor data. */ -NS_REFINED_FOR_SWIFT + NS_SWIFT_NAME(AnyTensor) __attribute__((deprecated("This API is experimental."))) @interface ExecuTorchTensor : NSObject @@ -112,21 +112,21 @@ __attribute__((deprecated("This API is experimental."))) * * @return An NSArray of NSNumber objects representing the size of each dimension. */ -@property(nonatomic, readonly) NSArray *shape; +@property(nonatomic, readonly) NSArray *shape NS_REFINED_FOR_SWIFT; /** * The order of dimensions in the tensor. * * @return An NSArray of NSNumber objects representing the tensor’s dimension order. */ -@property(nonatomic, readonly) NSArray *dimensionOrder; +@property(nonatomic, readonly) NSArray *dimensionOrder NS_REFINED_FOR_SWIFT; /** * The strides of the tensor. * * @return An NSArray of NSNumber objects representing the step sizes for each dimension. */ -@property(nonatomic, readonly) NSArray *strides; +@property(nonatomic, readonly) NSArray *strides NS_REFINED_FOR_SWIFT; /** * The dynamism of the tensor's shape. @@ -140,7 +140,7 @@ __attribute__((deprecated("This API is experimental."))) * * @return An NSInteger representing the total element count. */ -@property(nonatomic, readonly) NSInteger count; +@property(nonatomic, readonly) NSInteger count NS_REFINED_FOR_SWIFT; /** * Initializes a tensor with a native TensorPtr instance. @@ -149,7 +149,8 @@ __attribute__((deprecated("This API is experimental."))) * @return An initialized ExecuTorchTensor instance. */ - (instancetype)initWithNativeInstance:(void *)nativeInstance - NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE(""); + NS_DESIGNATED_INITIALIZER + NS_SWIFT_UNAVAILABLE(""); /** * Creates a new tensor that shares the underlying data storage with the @@ -200,7 +201,7 @@ __attribute__((deprecated("This API is experimental."))) */ - (BOOL)resizeToShape:(NSArray *)shape error:(NSError **)error - NS_SWIFT_NAME(resize(to:)); + NS_REFINED_FOR_SWIFT; /** * Determines whether the current tensor is equal to another tensor. @@ -209,7 +210,8 @@ __attribute__((deprecated("This API is experimental."))) * @return YES if the tensors have the same data type, shape, dimension order, * strides, and underlying data; otherwise, NO. */ -- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other; +- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other + NS_REFINED_FOR_SWIFT; + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; @@ -236,7 +238,8 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_REFINED_FOR_SWIFT; /** * Initializes a tensor without copying data using dynamic bound shape (default strides and dimension order). @@ -252,7 +255,8 @@ __attribute__((deprecated("This API is experimental."))) shape:(NSArray *)shape strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor without copying data, with an explicit shape dynamism. @@ -266,7 +270,8 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithBytesNoCopy:(void *)pointer shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor without copying data, specifying only the shape and data type. @@ -278,7 +283,8 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithBytesNoCopy:(void *)pointer shape:(NSArray *)shape - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); @end @@ -302,7 +308,8 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_REFINED_FOR_SWIFT; /** * Initializes a tensor by copying bytes from the provided pointer with dynamic bound shape. @@ -318,7 +325,8 @@ __attribute__((deprecated("This API is experimental."))) shape:(NSArray *)shape strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor by copying bytes from the provided pointer, specifying shape, data type, and explicit shape dynamism. @@ -332,7 +340,8 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithBytes:(const void *)pointer shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor by copying bytes from the provided pointer, specifying only the shape and data type. @@ -344,7 +353,8 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithBytes:(const void *)pointer shape:(NSArray *)shape - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); @end @@ -370,7 +380,8 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_REFINED_FOR_SWIFT; /** * Initializes a tensor using an NSData object as the underlying data buffer with dynamic bound shape. @@ -386,7 +397,8 @@ __attribute__((deprecated("This API is experimental."))) shape:(NSArray *)shape strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor using an NSData object as the underlying data buffer, specifying shape, data type, and explicit shape dynamism. @@ -400,7 +412,8 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithData:(NSData *)data shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism; + shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor using an NSData object as the underlying data buffer, specifying only the shape and data type. @@ -412,7 +425,8 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithData:(NSData *)data shape:(NSArray *)shape - dataType:(ExecuTorchDataType)dataType; + dataType:(ExecuTorchDataType)dataType + NS_SWIFT_UNAVAILABLE(""); @end @@ -437,7 +451,7 @@ __attribute__((deprecated("This API is experimental."))) dimensionOrder:(NSArray *)dimensionOrder dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(init(_:shape:strides:dimensionOrder:dataType:shapeDynamism:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, specifying shape, strides, dimension order, and data type, @@ -455,7 +469,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dimensionOrder:(NSArray *)dimensionOrder dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(init(_:shape:strides:dimensionOrder:dataType:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, specifying the desired shape, data type, and explicit shape dynamism. @@ -470,7 +484,7 @@ __attribute__((deprecated("This API is experimental."))) shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(init(_:shape:dataType:shapeDynamism:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values and a specified shape, @@ -484,7 +498,7 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithScalars:(NSArray *)scalars shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(init(_:shape:dataType:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, specifying the tensor data type and explicit shape dynamism. @@ -498,7 +512,7 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithScalars:(NSArray *)scalars dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(init(_:dataType:shapeDynamism:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, specifying the tensor data type. @@ -510,7 +524,7 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithScalars:(NSArray *)scalars dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(init(_:dataType:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, a specified shape and explicit shape dynamism. @@ -524,7 +538,7 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithScalars:(NSArray *)scalars shape:(NSArray *)shape shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(init(_:shape:shapeDynamism:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values and a specified shape. @@ -536,7 +550,7 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithScalars:(NSArray *)scalars shape:(NSArray *)shape - NS_SWIFT_NAME(init(_:shape:)); + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an array of scalar values, automatically deducing the tensor shape and data type. @@ -545,7 +559,7 @@ __attribute__((deprecated("This API is experimental."))) * @return An initialized ExecuTorchTensor instance with shape and data type deduced. */ - (instancetype)initWithScalars:(NSArray *)scalars - NS_SWIFT_NAME(init(_:)); + NS_SWIFT_UNAVAILABLE(""); @end @@ -559,7 +573,8 @@ __attribute__((deprecated("This API is experimental."))) * @return An initialized ExecuTorchTensor instance representing the scalar. */ - (instancetype)initWithScalar:(NSNumber *)scalar - dataType:(ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:)); + dataType:(ExecuTorchDataType)dataType + NS_REFINED_FOR_SWIFT; /** * Initializes a tensor with a single scalar value, automatically deducing its data type. @@ -567,7 +582,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An NSNumber representing the scalar value. * @return An initialized ExecuTorchTensor instance representing the scalar. */ -- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithScalar:(NSNumber *)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a byte scalar value. @@ -575,7 +591,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A uint8_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithByte:(uint8_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithByte:(uint8_t)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a char scalar value. @@ -583,7 +600,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An int8_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithChar:(int8_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithChar:(int8_t)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a short scalar value. @@ -591,7 +609,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An int16_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithShort:(int16_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithShort:(int16_t)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with an int scalar value. @@ -599,7 +618,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An int32_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithInt:(int32_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithInt:(int32_t)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a long scalar value. @@ -607,7 +627,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An int64_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithLong:(int64_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithLong:(int64_t)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a float scalar value. @@ -615,7 +636,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A float value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithFloat:(float)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithFloat:(float)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a double scalar value. @@ -623,7 +645,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A double value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithDouble:(double)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithDouble:(double)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a boolean scalar value. @@ -631,7 +654,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A BOOL value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithBool:(BOOL)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithBool:(BOOL)scalar + NS_SWIFT_UNAVAILABLE(""); /** * Initializes a tensor with a uint16 scalar value. @@ -639,7 +663,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A uint16_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithUInt16:(uint16_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithUInt16:(uint16_t)scalar + NS_SWIFT_NAME(init(_:)); /** * Initializes a tensor with a uint32 scalar value. @@ -647,7 +672,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A uint32_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithUInt32:(uint32_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithUInt32:(uint32_t)scalar + NS_SWIFT_NAME(init(_:)); /** * Initializes a tensor with a uint64 scalar value. @@ -655,7 +681,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar A uint64_t value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithUInt64:(uint64_t)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithUInt64:(uint64_t)scalar + NS_SWIFT_NAME(init(_:)); /** * Initializes a tensor with an NSInteger scalar value. @@ -663,7 +690,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An NSInteger value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithInteger:(NSInteger)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithInteger:(NSInteger)scalar + NS_SWIFT_NAME(init(_:)); /** * Initializes a tensor with an NSUInteger scalar value. @@ -671,7 +699,8 @@ __attribute__((deprecated("This API is experimental."))) * @param scalar An NSUInteger value. * @return An initialized ExecuTorchTensor instance. */ -- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar NS_SWIFT_NAME(init(_:)); +- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar + NS_SWIFT_NAME(init(_:)); @end @@ -692,7 +721,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(empty(shape:strides:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -706,7 +735,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)emptyTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(empty(shape:dataType:shapeDynamism:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -718,7 +747,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)emptyTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(empty(shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -732,7 +761,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)emptyTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(empty(like:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -744,7 +773,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)emptyTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(empty(like:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -754,7 +783,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new, empty ExecuTorchTensor instance with the same properties as the provided tensor. */ + (instancetype)emptyTensorLikeTensor:(ExecuTorchTensor *)tensor - NS_SWIFT_NAME(empty(like:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -778,7 +807,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(full(shape:scalar:strides:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -794,7 +823,7 @@ __attribute__((deprecated("This API is experimental."))) scalar:(NSNumber *)scalar dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(full(shape:scalar:dataType:shapeDynamism:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -809,7 +838,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)fullTensorWithShape:(NSArray *)shape scalar:(NSNumber *)scalar dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(full(shape:scalar:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -825,7 +854,7 @@ __attribute__((deprecated("This API is experimental."))) scalar:(NSNumber *)scalar dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(full(like:scalar:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -839,7 +868,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensr scalar:(NSNumber *)scalar dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(full(like:scalar:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -851,7 +880,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensr scalar:(NSNumber *)scalar - NS_SWIFT_NAME(full(like:scalar:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -871,7 +900,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)onesTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(ones(shape:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -883,7 +912,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)onesTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(ones(shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -897,7 +926,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(ones(like:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -909,7 +938,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(ones(like:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -919,7 +948,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new ExecuTorchTensor instance filled with ones. */ + (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor - NS_SWIFT_NAME(ones(like:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -939,7 +968,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)zerosTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(zeros(shape:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -951,7 +980,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)zerosTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(zeros(shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -965,7 +994,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(zeros(like:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -977,7 +1006,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(zeros(like:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -987,7 +1016,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new ExecuTorchTensor instance filled with zeros. */ + (instancetype)zerosTensorLikeTensor:(ExecuTorchTensor *)tensor - NS_SWIFT_NAME(zeros(like:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -1009,7 +1038,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(rand(shape:strides:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1023,7 +1052,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)randomTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(rand(shape:dataType:shapeDynamism:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1035,7 +1064,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)randomTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(rand(shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1049,7 +1078,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)randomTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(rand(like:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1061,7 +1090,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)randomTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(rand(like:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1071,7 +1100,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new ExecuTorchTensor instance filled with random values. */ + (instancetype)randomTensorLikeTensor:(ExecuTorchTensor *)tensor - NS_SWIFT_NAME(rand(like:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -1094,7 +1123,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randn(shape:strides:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1109,7 +1138,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)randomNormalTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randn(shape:dataType:shapeDynamism:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1122,7 +1151,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)randomNormalTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(randn(shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1137,7 +1166,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randn(like:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1150,7 +1179,7 @@ __attribute__((deprecated("This API is experimental."))) */ + (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(randn(like:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1160,7 +1189,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new ExecuTorchTensor instance filled with values from a normal distribution. */ + (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor - NS_SWIFT_NAME(randn(like:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end @@ -1187,7 +1216,7 @@ __attribute__((deprecated("This API is experimental."))) strides:(NSArray *)strides dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randint(low:high:shape:strides:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1206,7 +1235,7 @@ __attribute__((deprecated("This API is experimental."))) shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randint(low:high:shape:dataType:shapeDynamism:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1223,7 +1252,7 @@ __attribute__((deprecated("This API is experimental."))) high:(NSInteger)high shape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(randint(low:high:shape:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1242,7 +1271,7 @@ __attribute__((deprecated("This API is experimental."))) high:(NSInteger)high dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism - NS_SWIFT_NAME(randint(like:low:high:dataType:shapeDynamism:)) + NS_REFINED_FOR_SWIFT NS_RETURNS_RETAINED; /** @@ -1259,7 +1288,7 @@ __attribute__((deprecated("This API is experimental."))) low:(NSInteger)low high:(NSInteger)high dataType:(ExecuTorchDataType)dataType - NS_SWIFT_NAME(randint(like:low:high:dataType:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; /** @@ -1273,7 +1302,7 @@ __attribute__((deprecated("This API is experimental."))) + (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor low:(NSInteger)low high:(NSInteger)high - NS_SWIFT_NAME(randint(like:low:high:)) + NS_SWIFT_UNAVAILABLE("") NS_RETURNS_RETAINED; @end diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index d38fb277bff..3cf06207b45 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -17,6 +17,86 @@ using namespace executorch::extension; using namespace executorch::runtime; +static inline NSString *dataTypeDescription(ExecuTorchDataType dataType) { + switch (dataType) { + case ExecuTorchDataTypeByte: + return @"byte"; + case ExecuTorchDataTypeChar: + return @"char"; + case ExecuTorchDataTypeShort: + return @"short"; + case ExecuTorchDataTypeInt: + return @"int"; + case ExecuTorchDataTypeLong: + return @"long"; + case ExecuTorchDataTypeHalf: + return @"half"; + case ExecuTorchDataTypeFloat: + return @"float"; + case ExecuTorchDataTypeDouble: + return @"double"; + case ExecuTorchDataTypeComplexHalf: + return @"complexHalf"; + case ExecuTorchDataTypeComplexFloat: + return @"complexFloat"; + case ExecuTorchDataTypeComplexDouble: + return @"complexDouble"; + case ExecuTorchDataTypeBool: + return @"bool"; + case ExecuTorchDataTypeQInt8: + return @"qint8"; + case ExecuTorchDataTypeQUInt8: + return @"quint8"; + case ExecuTorchDataTypeQInt32: + return @"qint32"; + case ExecuTorchDataTypeBFloat16: + return @"bfloat16"; + case ExecuTorchDataTypeQUInt4x2: + return @"quint4x2"; + case ExecuTorchDataTypeQUInt2x4: + return @"quint2x4"; + case ExecuTorchDataTypeBits1x8: + return @"bits1x8"; + case ExecuTorchDataTypeBits2x4: + return @"bits2x4"; + case ExecuTorchDataTypeBits4x2: + return @"bits4x2"; + case ExecuTorchDataTypeBits8: + return @"bits8"; + case ExecuTorchDataTypeBits16: + return @"bits16"; + case ExecuTorchDataTypeFloat8_e5m2: + return @"float8_e5m2"; + case ExecuTorchDataTypeFloat8_e4m3fn: + return @"float8_e4m3fn"; + case ExecuTorchDataTypeFloat8_e5m2fnuz: + return @"float8_e5m2fnuz"; + case ExecuTorchDataTypeFloat8_e4m3fnuz: + return @"float8_e4m3fnuz"; + case ExecuTorchDataTypeUInt16: + return @"uint16"; + case ExecuTorchDataTypeUInt32: + return @"uint32"; + case ExecuTorchDataTypeUInt64: + return @"uint64"; + default: + return @"undefined"; + } +} + +static inline NSString *shapeDynamismDescription(ExecuTorchShapeDynamism dynamism) { + switch (dynamism) { + case ExecuTorchShapeDynamismStatic: + return @"static"; + case ExecuTorchShapeDynamismDynamicBound: + return @"dynamicBound"; + case ExecuTorchShapeDynamismDynamicUnbound: + return @"dynamicUnbound"; + default: + return @"undefined"; + } +} + NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) { return elementSize(static_cast(dataType)); } @@ -150,6 +230,70 @@ - (BOOL)isEqual:(nullable id)other { return [self isEqualToTensor:(ExecuTorchTensor *)other]; } +- (NSString *)description { + std::ostringstream os; + os << "Tensor {"; + os << "\n dataType: " << dataTypeDescription(static_cast(_tensor->scalar_type())).UTF8String << ","; + os << "\n shape: ["; + const auto& sizes = _tensor->sizes(); + for (size_t index = 0; index < sizes.size(); ++index) { + if (index > 0) { + os << ","; + } + os << sizes[index]; + } + os << "],"; + os << "\n strides: ["; + const auto& strides = _tensor->strides(); + for (size_t index = 0; index < strides.size(); ++index) { + if (index > 0) { + os << ","; + } + os << strides[index]; + } + os << "],"; + os << "\n dimensionOrder: ["; + const auto& dim_order = _tensor->dim_order(); + for (size_t index = 0; index < dim_order.size(); ++index) { + if (index > 0) { + os << ","; + } + os << static_cast(dim_order[index]); + } + os << "],"; + os << "\n shapeDynamism: " << shapeDynamismDescription(static_cast(_tensor->shape_dynamism())).UTF8String << ","; + auto const count = _tensor->numel(); + os << "\n count: " << count << ","; + os << "\n scalars: ["; + ET_SWITCH_REALHBBF16_TYPES( + static_cast(_tensor->scalar_type()), + nullptr, + "description", + CTYPE, + [&] { + auto const *pointer = reinterpret_cast(_tensor->unsafeGetTensorImpl()->data()); + auto const countToPrint = std::min(count, (ssize_t)100); + for (size_t index = 0; index < countToPrint; ++index) { + if (index > 0) { + os << ","; + } + if constexpr (std::is_same_v || + std::is_same_v) { + os << static_cast(pointer[index]); + } else { + os << pointer[index]; + } + } + if (count > countToPrint) { + os << ",..."; + } + } + ); + os << "]"; + os << "\n}"; + return @(os.str().c_str()); +} + @end @implementation ExecuTorchTensor (BytesNoCopy) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchValue.h b/extension/apple/ExecuTorch/Exported/ExecuTorchValue.h index d4c6d7168d3..4d09d826f1d 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchValue.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchValue.h @@ -174,7 +174,7 @@ __attribute__((deprecated("This API is experimental."))) * @return A new ExecuTorchValue instance with a tag of ExecuTorchValueTagTensor. */ + (instancetype)valueWithTensor:(ExecuTorchTensor *)value - NS_REFINED_FOR_SWIFT + NS_SWIFT_NAME(init(_:)) NS_RETURNS_RETAINED; /** diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm index dd4eed7157e..6ba03dc50f9 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm @@ -29,6 +29,36 @@ static inline ExecuTorchValueTag deduceValueTag(NSNumber *number) { } } +static inline NSString *valueTagDescription(ExecuTorchValueTag tag) { + switch (tag) { + case ExecuTorchValueTagNone: + return @"none"; + case ExecuTorchValueTagTensor: + return @"tensor"; + case ExecuTorchValueTagString: + return @"string"; + case ExecuTorchValueTagDouble: + return @"double"; + case ExecuTorchValueTagInteger: + return @"integer"; + case ExecuTorchValueTagBoolean: + return @"boolean"; + case ExecuTorchValueTagBooleanList: + return @"boolean_list"; + case ExecuTorchValueTagDoubleList: + return @"double_list"; + case ExecuTorchValueTagIntegerList: + return @"integer_list"; + case ExecuTorchValueTagTensorList: + return @"tensor_list"; + case ExecuTorchValueTagScalarList: + return @"scalar_list"; + case ExecuTorchValueTagOptionalTensorList: + return @"optional_tensor_list"; + } + return @"undefined"; +} + @interface ExecuTorchValue () - (instancetype)initWithTag:(ExecuTorchValueTag)tag @@ -195,4 +225,24 @@ - (BOOL)isEqual:(nullable id)other { return [self isEqualToValue:(ExecuTorchValue *)other]; } +- (NSString *)description { + NSMutableString *string = [NSMutableString new]; + [string appendString:@"Value {"]; + [string appendFormat:@"\n tag: %@", valueTagDescription(_tag)]; + [string appendString:@","]; + [string appendString:@"\n value: "]; + if (_value) { + NSString *valueDescription = [_value description]; + [string appendString:[_value description]]; + [string replaceOccurrencesOfString:@"\n" + withString:@"\n " + options:0 + range:NSMakeRange(string.length - valueDescription.length, valueDescription.length)]; + } else { + [string appendString:@"nil"]; + } + [string appendString:@"\n}"]; + return string; +} + @end diff --git a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj index 47a7af09dbd..c9b68f250c1 100644 --- a/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj +++ b/extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj @@ -38,6 +38,7 @@ F292B01D2D88AF3500BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B0162D88AF3500BE6839 /* bpe_tokenizer_base.cpp */; }; F292B0202D88AF3500BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B0172D88AF3500BE6839 /* llama2c_tokenizer.cpp */; }; F292B0212D88AF3500BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B01A2D88AF3500BE6839 /* tiktoken.cpp */; }; + F2E1B5172E03AC19002C9718 /* sentencepiece.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F2E1B5162E03AC19002C9718 /* sentencepiece.cpp */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -110,6 +111,7 @@ F292B0292D88AF4800BE6839 /* result.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = result.h; sourceTree = ""; }; F292B02B2D88AF4800BE6839 /* tiktoken.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = tiktoken.h; sourceTree = ""; }; F292B02D2D88AF4800BE6839 /* tokenizer.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = tokenizer.h; sourceTree = ""; }; + F2E1B5162E03AC19002C9718 /* sentencepiece.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = sentencepiece.cpp; path = src/sentencepiece.cpp; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -183,6 +185,7 @@ 032A74022CAFBB7800932D36 /* tokenizers */ = { isa = PBXGroup; children = ( + F2E1B5162E03AC19002C9718 /* sentencepiece.cpp */, 3C6ABD322DFA27DE0015DE55 /* regex_lookahead.cpp */, 30AA4B592DC0766800B1BE50 /* hf_tokenizer.cpp */, 30AA4B5A2DC0766800B1BE50 /* pcre2_regex.cpp */, @@ -400,7 +403,7 @@ ); runOnlyForDeploymentPostprocessing = 0; shellPath = /bin/sh; - shellScript = "set -e\n\nif ! command -v cmake &> /dev/null\nthen\n echo \"Cmake not found, please install Cmake. \\n1. Download Cmake.app from https://cmake.org/download with version > 3.19. \\n2. Install it to Applications/ folder and run sudo /Applications/CMake.app/Contents/bin/cmake-gui --install to install CMake commandline tools.\"\n exit 1\nfi\n\nCMAKE_DIR=\"$TEMP_DIR/cmake\"\nrm -rf \"$CMAKE_DIR\"\n\nPLATFORM=\"SIMULATORARM64\"\nDEPLOYMENT_TARGET=\"17.0\"\n\nif [[ \"$PLATFORM_NAME\" == *\"iphoneos\"* ]]; then\n PLATFORM=\"OS64\"\nelif [[ \"$PLATFORM_NAME\" == *\"macos\"* ]]; then\n PLATFORM=\"MAC_ARM64\"\n DEPLOYMENT_TARGET=\"12.0\"\nfi\n\ncmake_build() {\n local src_dir=$1\n local target=$2\n shift 2\n local extra_args=(\"$@\")\n local build_dir=\"$CMAKE_DIR/build/$(basename \"$src_dir\")\"\n\n mkdir -p \"$build_dir\" && cd \"$build_dir\"\n\n if [[ \"$PLATFORM\" == \"MAC_ARM64\" ]]; then\n extra_args+=(-DCMAKE_INSTALL_BUNDLEDIR=\"${CMAKE_DIR}/bin\")\n extra_args+=(-DCMAKE_MACOSX_BUNDLE=OFF)\n fi\n cmake -G Xcode \\\n -DCMAKE_BUILD_TYPE=\"Release\" \\\n -DCMAKE_CXX_STANDARD=17 \\\n -DCMAKE_TOOLCHAIN_FILE=\"$SRCROOT/../../../../third-party/ios-cmake/ios.toolchain.cmake\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD=\"c++17\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY=\"libc++\" \\\n -DPLATFORM=\"$PLATFORM\" \\\n -DDEPLOYMENT_TARGET=\"$DEPLOYMENT_TARGET\" \\\n -DCMAKE_INSTALL_PREFIX=\"$CMAKE_DIR\" \\\n \"${extra_args[@]}\" \\\n \"$src_dir\"\n cmake --build . --config \"Release\" --target \"$target\"\n if [[ \"$target\" == \"install\" ]]; then\n cmake --install . --prefix \"$CMAKE_DIR\"\n fi\n}\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/abseil-cpp\" \"install\" \\\n -DABSL_PROPAGATE_CXX_STD=ON\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/re2\" \"install\"\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/pcre2\" \"install\" \\\n -DPCRE2_BUILD_PCRE2_8=ON \\\n -DPCRE2_BUILD_PCRE2_16=OFF \\\n -DPCRE2_BUILD_PCRE2_32=OFF \\\n -DPCRE2_BUILD_TESTS=OFF \\\n -DPCRE2_BUILD_PCRE2GREP=OFF \\\n -DPCRE2_BUILD_PCRE2TEST=OFF \\\n -DPCRE2_BUILD_PCRE2GPERF=OFF \\\n -DPCRE2_BUILD_DOCS=OFF \\\n -DPCRE2_BUILD_LIBPCRE2_PDB=OFF \\\n -DSUPPORT_REGEX_LOOKAHEAD=ON\n \ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/sentencepiece\" \"sentencepiece-static\" \\\n -DSPM_ENABLE_SHARED=OFF\n \ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/llama.cpp-unicode\" \"install\"\n \n# Include the single header for json.\nmkdir -p \"$CMAKE_DIR/include/nlohmann\"\ncp \"$SRCROOT/../../../llm/tokenizers/third-party/json/single_include/nlohmann/json.hpp\" \"$CMAKE_DIR/include/nlohmann/json.hpp\"\n\necho \"$(find $CMAKE_DIR/lib -name \"*.a\" | sed -E 's|^.*/lib([^/]+)\\.a|-l\\1|g' | tr '\\n' ' ')\" > \"$CMAKE_DIR/linker_flags\"\n"; + shellScript = "set -e\n\nif ! command -v cmake &> /dev/null\nthen\n echo \"Cmake not found, please install Cmake. \\n1. Download Cmake.app from https://cmake.org/download with version > 3.19. \\n2. Install it to Applications/ folder and run sudo /Applications/CMake.app/Contents/bin/cmake-gui --install to install CMake commandline tools.\"\n exit 1\nfi\n\nCMAKE_DIR=\"$TEMP_DIR/cmake\"\nrm -rf \"$CMAKE_DIR\"\n\nPLATFORM=\"SIMULATORARM64\"\nDEPLOYMENT_TARGET=\"17.0\"\n\nif [[ \"$PLATFORM_NAME\" == *\"iphoneos\"* ]]; then\n PLATFORM=\"OS64\"\nelif [[ \"$PLATFORM_NAME\" == *\"macos\"* ]]; then\n PLATFORM=\"MAC_ARM64\"\n DEPLOYMENT_TARGET=\"12.0\"\nfi\n\ncmake_build() {\n local src_dir target do_install=0\n local extra_args=()\n local build_dir\n # Parse arguments\n src_dir=\"$1\"\n shift\n target=\"$1\"\n if [[ \"$target\" == \"install\" ]]; then\n # Usage: cmake_build install [extra_args...]\n do_install=1\n shift\n else\n # Usage: cmake_build [install] [extra_args...]\n shift\n if [[ \"$1\" == \"install\" ]]; then\n do_install=1\n shift\n fi\n fi\n # Collect any remaining arguments as extra_args\n extra_args=(\"$@\")\n build_dir=\"$CMAKE_DIR/build/$(basename \"$src_dir\")\"\n mkdir -p \"$build_dir\" || { echo \"Failed to create build dir\"; return 1; }\n pushd \"$build_dir\" > /dev/null || { echo \"Failed to enter build dir\"; return 1; }\n # Platform-specific CMake args\n if [[ \"$PLATFORM\" == \"MAC_ARM64\" ]]; then\n extra_args+=(-DCMAKE_INSTALL_BUNDLEDIR=\"${CMAKE_DIR}/bin\")\n extra_args+=(-DCMAKE_MACOSX_BUNDLE=OFF)\n fi\n # Configure\n cmake -G Xcode \\\n -DCMAKE_BUILD_TYPE=\"Release\" \\\n -DCMAKE_CXX_STANDARD=17 \\\n -DCMAKE_TOOLCHAIN_FILE=\"$SRCROOT/../../../../third-party/ios-cmake/ios.toolchain.cmake\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LANGUAGE_STANDARD=\"c++17\" \\\n -DCMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY=\"libc++\" \\\n -DPLATFORM=\"$PLATFORM\" \\\n -DDEPLOYMENT_TARGET=\"$DEPLOYMENT_TARGET\" \\\n -DCMAKE_INSTALL_PREFIX=\"$CMAKE_DIR\" \\\n \"${extra_args[@]}\" \\\n \"$src_dir\" || { echo \"CMake configure failed\"; popd > /dev/null; return 1; }\n # Build\n cmake --build . --config \"Release\" --target $target\n # Install if requested\n if [[ $do_install -eq 1 ]]; then\n cmake --install . --prefix \"$CMAKE_DIR\" || echo \"Ignoring install failures\"\n fi\n}\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/abseil-cpp\" \"install\" \\\n -DABSL_PROPAGATE_CXX_STD=ON\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/re2\" \"install\"\n\ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/pcre2\" \"install\" \\\n -DPCRE2_BUILD_PCRE2_8=ON \\\n -DPCRE2_BUILD_PCRE2_16=OFF \\\n -DPCRE2_BUILD_PCRE2_32=OFF \\\n -DPCRE2_BUILD_TESTS=OFF \\\n -DPCRE2_BUILD_PCRE2GREP=OFF \\\n -DPCRE2_BUILD_PCRE2TEST=OFF \\\n -DPCRE2_BUILD_PCRE2GPERF=OFF \\\n -DPCRE2_BUILD_DOCS=OFF \\\n -DPCRE2_BUILD_LIBPCRE2_PDB=OFF \\\n -DSUPPORT_REGEX_LOOKAHEAD=ON\n \ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/sentencepiece\" \"sentencepiece-static sentencepiece_train-static\" \"install\" \\\n -DSPM_ENABLE_SHARED=OFF \\\n -DSPM_BUILD_TEST=OFF \\\n -DCMAKE_SYSTEM_NAME=\"iOS\"\n \ncmake_build \"$SRCROOT/../../../llm/tokenizers/third-party/llama.cpp-unicode\" \"install\"\n \n# Include the single header for json.\nmkdir -p \"$CMAKE_DIR/include/nlohmann\"\ncp \"$SRCROOT/../../../llm/tokenizers/third-party/json/single_include/nlohmann/json.hpp\" \"$CMAKE_DIR/include/nlohmann/json.hpp\"\n\necho \"$(find $CMAKE_DIR/lib -name \"*.a\" | sed -E 's|^.*/lib([^/]+)\\.a|-l\\1|g' | tr '\\n' ' ')\" > \"$CMAKE_DIR/linker_flags\"\n"; }; /* End PBXShellScriptBuildPhase section */ @@ -426,6 +429,7 @@ F292B01D2D88AF3500BE6839 /* bpe_tokenizer_base.cpp in Sources */, F292B0202D88AF3500BE6839 /* llama2c_tokenizer.cpp in Sources */, F292B0212D88AF3500BE6839 /* tiktoken.cpp in Sources */, + F2E1B5172E03AC19002C9718 /* sentencepiece.cpp in Sources */, 03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */, 032A74232CAFC1B300932D36 /* runner.cpp in Sources */, 03B2D37A2C8A515C0046936E /* GenericTests.mm in Sources */, diff --git a/extension/benchmark/apple/Benchmark/Tests/Tests.xcconfig b/extension/benchmark/apple/Benchmark/Tests/Tests.xcconfig index 0172f28b1bb..bf915abc25b 100644 --- a/extension/benchmark/apple/Benchmark/Tests/Tests.xcconfig +++ b/extension/benchmark/apple/Benchmark/Tests/Tests.xcconfig @@ -17,7 +17,9 @@ OTHER_LDFLAGS = $(inherited) \ HEADER_SEARCH_PATHS = $(inherited) \ $(SRCROOT)/../../../../.. \ $(TEMP_DIR)/cmake/include \ - $(SRCROOT)/../../../../extension/llm/tokenizers/include + $(SRCROOT)/../../../../extension/llm/tokenizers/include \ + $(SRCROOT)/../../../../extension/llm/tokenizers/third-party/sentencepiece \ + $(SRCROOT)/../../../../extension/llm/tokenizers/third-party/sentencepiece/src LIBRARY_SEARCH_PATHS = $(inherited) \ $(TEMP_DIR)/cmake/lib diff --git a/extension/data_loader/mmap_data_loader.cpp b/extension/data_loader/mmap_data_loader.cpp index 53fd7bdf624..10bd2f35f5e 100644 --- a/extension/data_loader/mmap_data_loader.cpp +++ b/extension/data_loader/mmap_data_loader.cpp @@ -150,10 +150,10 @@ void MunmapSegment(void* context, void* data, size_t size) { } } // namespace -Result MmapDataLoader::load( - size_t offset, - size_t size, - ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { +/** + * Validates that file read range is within bounds. + */ +Error MmapDataLoader::validate_input(size_t offset, size_t size) const { ET_CHECK_OR_RETURN_ERROR( // Probably had its value moved to another instance. fd_ >= 0, @@ -173,6 +173,18 @@ Result MmapDataLoader::load( InvalidArgument, "Offset %zu too large for off_t", offset); + return Error::Ok; +} + +Result MmapDataLoader::load( + size_t offset, + size_t size, + ET_UNUSED const DataLoader::SegmentInfo& segment_info) const { + // Ensure read range is valid. + auto validation_err = validate_input(offset, size); + if (validation_err != Error::Ok) { + return validation_err; + } // mmap() will fail if the size is zero. if (size == 0) { @@ -193,14 +205,13 @@ Result MmapDataLoader::load( map_size = file_size_ - range.start; } - // Map the pages read-only. MAP_PRIVATE vs. MAP_SHARED doesn't matter since - // the data is read-only, but use PRIVATE just to further avoid accidentally - // modifying the file. + // Map the pages read-only. Use shared mappings so that other processes + // can also map the same pages and share the same memory. void* pages = ::mmap( nullptr, map_size, PROT_READ, - MAP_PRIVATE, + MAP_SHARED, fd_, static_cast(range.start)); ET_CHECK_OR_RETURN_ERROR( @@ -268,5 +279,69 @@ Result MmapDataLoader::size() const { return file_size_; } +Error MmapDataLoader::load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const { + ET_CHECK_OR_RETURN_ERROR( + buffer != nullptr, InvalidArgument, "Buffer is null"); + + // Ensure read range is valid. + auto err = validate_input(offset, size); + if (err != Error::Ok) { + return err; + } + + // Nothing to copy. + if (size == 0) { + return Error::Ok; + } + + // Find the range of pages that covers the requested region. + Range range = + get_overlapping_pages(static_cast(offset), size, page_size_); + + size_t map_size = range.size; + if (range.start + map_size > file_size_) { + // Clamp to the end of the file. + // + // The Windows implementation of mmap uses CreateFileMapping which returns + // error STATUS_SECTION_TOO_BIG (0xc0000040) if we try to map past the end + // of the last page of a file mapped in as read-only. + map_size = file_size_ - range.start; + } + + // Map the pages read-only. MAP_PRIVATE vs. MAP_SHARED doesn't matter since + // the data is read-only, but use PRIVATE just to further avoid accidentally + // modifying the file. + void* pages = ::mmap( + nullptr, + map_size, + PROT_READ, + MAP_PRIVATE, + fd_, + static_cast(range.start)); + ET_CHECK_OR_RETURN_ERROR( + pages != MAP_FAILED, + AccessFailed, + "Failed to map %s: mmap(..., size=%zd, ..., fd=%d, offset=0x%zx)", + file_name_, + range.size, + fd_, + range.start); + + // Offset into mapped region. + const size_t map_delta = offset - range.start; + + // Copy data into caller's buffer. + std::memcpy(buffer, static_cast(pages) + map_delta, size); + + // Unmap mapped region. + ::munmap(pages, map_size); + + return Error::Ok; +} + } // namespace extension } // namespace executorch diff --git a/extension/data_loader/mmap_data_loader.h b/extension/data_loader/mmap_data_loader.h index c55f81a490b..c0496a39d4b 100644 --- a/extension/data_loader/mmap_data_loader.h +++ b/extension/data_loader/mmap_data_loader.h @@ -95,6 +95,13 @@ class MmapDataLoader final : public executorch::runtime::DataLoader { ET_NODISCARD executorch::runtime::Result size() const override; + ET_NODISCARD + executorch::runtime::Error load_into( + size_t offset, + size_t size, + ET_UNUSED const SegmentInfo& segment_info, + void* buffer) const override; + private: MmapDataLoader( int fd, @@ -113,6 +120,10 @@ class MmapDataLoader final : public executorch::runtime::DataLoader { MmapDataLoader& operator=(const MmapDataLoader&) = delete; MmapDataLoader& operator=(MmapDataLoader&&) = delete; + ET_NODISCARD executorch::runtime::Error validate_input( + size_t offset, + size_t size) const; + const char* const file_name_; // String data is owned by the instance. const size_t file_size_; const size_t page_size_; diff --git a/extension/data_loader/test/mmap_data_loader_test.cpp b/extension/data_loader/test/mmap_data_loader_test.cpp index c01b3454493..76b972c46d0 100644 --- a/extension/data_loader/test/mmap_data_loader_test.cpp +++ b/extension/data_loader/test/mmap_data_loader_test.cpp @@ -376,3 +376,56 @@ TEST_F(MmapDataLoaderTest, DEPRECATEDFrom) { ASSERT_EQ(total_size.error(), Error::Ok); EXPECT_EQ(*total_size, contents_size); } + +// Tests that load_into copies bytes correctly. +TEST_F(MmapDataLoaderTest, LoadIntoCopiesCorrectly) { + // Create a test string. + const char* test_text = "FILE_CONTENTS"; + const size_t text_size = std::strlen(test_text); + TempFile tf(test_text); + + // Wrap it in a loader. + Result mdl = MmapDataLoader::from(tf.path().c_str()); + ASSERT_EQ(mdl.error(), Error::Ok); + + // Destination buffer. + std::vector dst(text_size); + + // Call load_into() + Error err = mdl->load_into( + /*offset=*/0, + /*size=*/text_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program), + dst.data()); + ASSERT_EQ(err, Error::Ok); + + // Verify memory copied correctly. + EXPECT_EQ(0, std::memcmp(dst.data(), test_text, text_size)); +} + +// Tests that load_into copies offset slice correctly. +TEST_F(MmapDataLoaderTest, LoadIntoCopiesOffsetCorrectly) { + // Create a test string. + const char* contents = "ABCDEFGH"; + TempFile tf(contents); + + // Wrap it in a loader. + Result mdl = MmapDataLoader::from(tf.path().c_str()); + ASSERT_EQ(mdl.error(), Error::Ok); + + // Copying 3 bytes starting at offset 2 = "CDE" + const size_t offset = 2; + const size_t size = 3; + uint8_t dst[size]; + + // Call load_into() + Error err = mdl->load_into( + offset, + size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program), + dst); + ASSERT_EQ(err, Error::Ok); + + // Verify memory copied correctly. + EXPECT_EQ(0, std::memcmp(dst, contents + offset, size)); +} \ No newline at end of file diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md new file mode 100644 index 00000000000..1ac27306c86 --- /dev/null +++ b/extension/llm/export/README.md @@ -0,0 +1,137 @@ +# LLM Export API + +This directory contains the unified API for exporting Large Language Models (LLMs) to ExecuTorch. The `export_llm` module provides a streamlined interface to convert various LLM architectures to optimized `.pte` files for on-device inference. + +## Overview + +The LLM export process transforms a model from its original format to an optimized representation suitable for mobile and edge devices. This involves several key steps: + +1. **Model Instantiation**: Load the model architecture and weights from sources like Hugging Face +2. **Source Transformations**: Apply model-specific optimizations and quantization +3. **IR Export**: Convert to intermediate representations (EXIR, Edge dialect) +4. **Graph Transformations**: Apply backend-specific optimizations and PT2E quantization +5. **Backend Delegation**: Partition operations to hardware-specific backends (XNNPACK, CoreML, QNN, etc.) +6. **Serialization**: Export to final ExecuTorch `.pte` format + +## Supported Models + +- **Llama**: Llama 2, Llama 3, Llama 3.1, Llama 3.2 (1B, 3B, 8B variants) +- **Qwen**: Qwen 2.5, Qwen 3 (0.6B, 1.7B, 4B variants) +- **Phi**: Phi-3-Mini, Phi-4-Mini +- **Stories**: Stories110M (educational model) +- **SmolLM**: SmolLM2 + +## Usage + +The export API supports two configuration approaches: + +### Option 1: Hydra CLI Arguments + +Use structured configuration arguments directly on the command line: + +```bash +python -m extension.llm.export.export_llm \ + base.model_class=llama3 \ + model.use_sdpa_with_kv_cache=True \ + model.use_kv_cache=True \ + export.max_seq_length=128 \ + debug.verbose=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode=8da4w +``` + +### Option 2: Configuration File + +Create a YAML configuration file and reference it: + +```bash +python -m extension.llm.export.export_llm --config my_config.yaml +``` + +Example `my_config.yaml`: +```yaml +base: + model_class: llama3 + tokenizer_path: /path/to/tokenizer.json + +model: + use_kv_cache: true + use_sdpa_with_kv_cache: true + enable_dynamic_shape: true + +export: + max_seq_length: 512 + output_dir: ./exported_models + output_name: llama3_optimized.pte + +quantization: + qmode: 8da4w + group_size: 32 + +backend: + xnnpack: + enabled: true + extended_ops: true + +debug: + verbose: true +``` + +**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both. + +## Example Commands + +### Export Qwen3 0.6B with XNNPACK backend and quantization +```bash +python -m extension.llm.export.export_llm \ + base.model_class=qwen3-0_6b \ + base.params=examples/models/qwen3/0_6b_config.json \ + base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \ + model.use_kv_cache=true \ + model.use_sdpa_with_kv_cache=true \ + model.dtype_override=FP32 \ + export.max_seq_length=512 \ + export.output_name=qwen3_0_6b.pte \ + quantization.qmode=8da4w \ + backend.xnnpack.enabled=true \ + backend.xnnpack.extended_ops=true \ + debug.verbose=true +``` + +### Export Phi-4-Mini with custom checkpoint +```bash +python -m extension.llm.export.export_llm \ + base.model_class=phi_4_mini \ + base.checkpoint=/path/to/phi4_checkpoint.pth \ + base.params=examples/models/phi-4-mini/config.json \ + base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \ + model.use_kv_cache=true \ + model.use_sdpa_with_kv_cache=true \ + export.max_seq_length=256 \ + export.output_name=phi4_mini.pte \ + backend.xnnpack.enabled=true \ + debug.verbose=true +``` + +### Export with CoreML backend (iOS optimization) +```bash +python -m extension.llm.export.export_llm \ + base.model_class=llama3 \ + model.use_kv_cache=true \ + export.max_seq_length=128 \ + backend.coreml.enabled=true \ + backend.coreml.compute_units=ALL \ + quantization.pt2e_quantize=coreml_c4w \ + debug.verbose=true +``` + +## Configuration Options + +For a complete reference of all available configuration options, see the [LlmConfig class definition](../../../examples/models/llama/config/llm_config.py) which documents all supported parameters for base, model, export, quantization, backend, and debug configurations. + +## Further Reading + +- [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide +- [LLM Runner](../runner/) - Running exported models +- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview \ No newline at end of file diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS index a85370fc49c..7acf026a8da 100644 --- a/extension/llm/export/TARGETS +++ b/extension/llm/export/TARGETS @@ -47,6 +47,41 @@ runtime.python_library( ], ) +runtime.python_binary( + name = "export_llm", + srcs = [ + "export_llm.py", + ], + main_function = "executorch.extension.llm.export.export_llm.main", + preload_deps = [ + "//executorch/extension/llm/custom_ops:model_sharding_py", + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/kernels/quantized:aot_lib", + ], + deps = [ + "fbsource//third-party/pypi/hydra-core:hydra-core", + "fbsource//third-party/pypi/omegaconf:omegaconf", + "//executorch/examples/models/llama:export_library", + "//executorch/extension/pybindings:aten_lib", + ], +) + +runtime.python_library( + name = "export_llm_lib", + srcs = [ + "export_llm.py", + ], + deps = [ + "fbsource//third-party/pypi/hydra-core:hydra-core", + "fbsource//third-party/pypi/omegaconf:omegaconf", + "//executorch/examples/models/llama:export_library", + ], + visibility = [ + "//executorch/examples/...", + "//executorch/extension/llm/...", + ], +) + runtime.python_test( name = "export_passes_test", srcs = [ diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 8b81587c434..4128bfd8198 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -133,6 +133,19 @@ def __init__( self.output_dir = "." self._saved_pte_filename = None + def __post_init__(self): + """ + Post init function to update metadata based on dynamic shape + """ + dynamic_shape = self._get_dynamic_shape() + if dynamic_shape is not None: + token_dim = dynamic_shape[0][1] + if self.verbose: + logging.info( + f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}" + ) + self.metadata["get_max_seq_len"] = token_dim.max + def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ Set the directory where the .pte file will be saved. @@ -180,14 +193,19 @@ def _get_dynamic_shape(self) -> Any: if self.dynamic_shapes: return self.dynamic_shapes - dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1) if self.enable_dynamic_shape: if not self.use_kv_cache: # Only one input argument: tokens - self.dynamic_shapes = ({1: dim},) + # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad + self.dynamic_shapes = ( + {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)}, + ) else: # Two input arguments: tokens and input_pos but input_pos is static shape - self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}}) + self.dynamic_shapes = ( + {1: torch.export.Dim("token_dim", max=self.max_seq_len)}, + {"input_pos": {0: 1}}, + ) else: # Two input arguments: tokens and input_pos but both are of static shape self.dynamic_shapes = None diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py new file mode 100644 index 00000000000..e995b329f30 --- /dev/null +++ b/extension/llm/export/export_llm.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export an LLM with ExecuTorch. Currently follows the following steps: +1. Instantiate our custom PyTorch transformer definition from examples/llama/models/llama_transformer.py. +2. Load weights into the model. +3. Apply source transformations/TorchAO quantization. +4. Export model to intermediate IRs. +5. Graph transformations/PT2E quantization. +6. Partition graph and delegate to backend(s). +7. Export to final ExecuTorch .pte format. + +Example usage using full CLI arguments: +python -m extension.llm.export.export_llm \ + base.model_class="llama3" \ + model.use_sdpa_with_kv_cache=True \ + model.use_kv_cache=True \ + debug.verbose=True \ + backend.xnnpack.enabled=True \ + backend.xnnpack.extended_ops=True \ + quantization.qmode="8da4w" + +Example usage using config file: +python -m extension.llm.export.export_llm \ + --config example_llm_config.yaml +""" + +import argparse +import sys +from typing import Any, List, Tuple + +import hydra + +from executorch.examples.models.llama.config.llm_config import LlmConfig +from executorch.examples.models.llama.export_llama_lib import export_llama +from hydra.core.config_store import ConfigStore +from omegaconf import OmegaConf + +cs = ConfigStore.instance() +cs.store(name="llm_config", node=LlmConfig) + + +def parse_config_arg() -> Tuple[str, List[Any]]: + """First parse out the arg for whether to use Hydra or the old CLI.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument("--config", type=str, help="Path to the LlmConfig file") + args, remaining = parser.parse_known_args() + return args.config, remaining + + +def pop_config_arg() -> str: + """ + Removes '--config' and its value from sys.argv. + Assumes --config is specified and argparse has already validated the args. + """ + idx = sys.argv.index("--config") + value = sys.argv[idx + 1] + del sys.argv[idx : idx + 2] + return value + + +@hydra.main(version_base=None, config_name="llm_config") +def hydra_main(llm_config: LlmConfig) -> None: + export_llama(OmegaConf.to_object(llm_config)) + + +def main() -> None: + config, remaining_args = parse_config_arg() + if config: + # Check if there are any remaining hydra CLI args when --config is specified + # This might change in the future to allow overriding config file values + if remaining_args: + raise ValueError( + "Cannot specify additional CLI arguments when using --config. " + f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both." + ) + + config_file_path = pop_config_arg() + default_llm_config = LlmConfig() + llm_config_from_file = OmegaConf.load(config_file_path) + # Override defaults with values specified in the .yaml provided by --config. + merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file) + export_llama(merged_llm_config) + else: + hydra_main() + + +if __name__ == "__main__": + main() diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 20604bbf635..7b093a7f1a3 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -216,4 +216,6 @@ def get_qnn_partitioner( ), skip_node_id_set={}, skip_node_op_set=skip_node_op_set, + # TODO: if deprecated legacy export, skip_mutable_buffer can be set False + skip_mutable_buffer=True, ) diff --git a/extension/llm/export/test/test_builder.py b/extension/llm/export/test/test_builder.py index 7883480c1e7..8bf591813ec 100644 --- a/extension/llm/export/test/test_builder.py +++ b/extension/llm/export/test/test_builder.py @@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non # Check first element (tokens dimension) self.assertIsInstance(result[0], dict) self.assertIn(1, result[0]) - self.assertEqual(result[0][1].max, self.max_seq_len - 1) + self.assertEqual(result[0][1].max, self.max_seq_len) # Check second element (input_pos dimension) self.assertIsInstance(result[1], dict) diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py new file mode 100644 index 00000000000..7d17b7819d3 --- /dev/null +++ b/extension/llm/export/test/test_export_llm.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from executorch.extension.llm.export.export_llm import ( + main, + parse_config_arg, + pop_config_arg, +) + + +class TestExportLlm(unittest.TestCase): + def test_parse_config_arg_with_config(self) -> None: + """Test parse_config_arg when --config is provided.""" + # Mock sys.argv to include --config + test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] + with patch.object(sys, "argv", test_argv): + config_path, remaining = parse_config_arg() + self.assertEqual(config_path, "test_config.yaml") + self.assertEqual(remaining, ["extra", "args"]) + + def test_parse_config_arg_without_config(self) -> None: + """Test parse_config_arg when --config is not provided.""" + test_argv = ["script.py", "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + config_path, remaining = parse_config_arg() + self.assertIsNone(config_path) + self.assertEqual(remaining, ["debug.verbose=True"]) + + def test_pop_config_arg(self) -> None: + """Test pop_config_arg removes --config and its value from sys.argv.""" + test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] + with patch.object(sys, "argv", test_argv): + config_path = pop_config_arg() + self.assertEqual(config_path, "test_config.yaml") + self.assertEqual(sys.argv, ["script.py", "other", "args"]) + + @patch("executorch.extension.llm.export.export_llm.export_llama") + def test_with_config(self, mock_export_llama: MagicMock) -> None: + """Test main function with --config file and no hydra args.""" + # Create a temporary config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write( + """ +base: + model_class: llama2 + tokenizer_path: /path/to/tokenizer.json + preq_mode: preq_8da4w +model: + dtype_override: fp16 +export: + max_seq_length: 256 +quantization: + pt2e_quantize: xnnpack_dynamic + use_spin_quant: cuda +backend: + coreml: + quantize: c4w + compute_units: cpu_and_gpu +""" + ) + config_file = f.name + + try: + test_argv = ["script.py", "--config", config_file] + with patch.object(sys, "argv", test_argv): + main() + + # Verify export_llama was called with config + mock_export_llama.assert_called_once() + called_config = mock_export_llama.call_args[0][0] + self.assertEqual( + called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json" + ) + self.assertEqual(called_config["base"]["model_class"], "llama2") + self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w") + self.assertEqual(called_config["model"]["dtype_override"].value, "fp16") + self.assertEqual(called_config["export"]["max_seq_length"], 256) + self.assertEqual( + called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic" + ) + self.assertEqual( + called_config["quantization"]["use_spin_quant"].value, "cuda" + ) + self.assertEqual( + called_config["backend"]["coreml"]["quantize"].value, "c4w" + ) + self.assertEqual( + called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu" + ) + finally: + os.unlink(config_file) + + def test_with_cli_args(self) -> None: + """Test main function with only hydra CLI args.""" + test_argv = ["script.py", "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with patch( + "executorch.extension.llm.export.export_llm.hydra_main" + ) as mock_hydra: + main() + mock_hydra.assert_called_once() + + def test_config_with_cli_args_error(self) -> None: + """Test that --config rejects additional CLI arguments to prevent mixing approaches.""" + # Create a temporary config file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("base:\n checkpoint: /path/to/checkpoint.pth") + config_file = f.name + + try: + test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with self.assertRaises(ValueError) as cm: + main() + + error_msg = str(cm.exception) + self.assertIn( + "Cannot specify additional CLI arguments when using --config", + error_msg, + ) + finally: + os.unlink(config_file) + + def test_config_rejects_multiple_cli_args(self) -> None: + """Test that --config rejects multiple CLI arguments (not just single ones).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("export:\n max_seq_length: 128") + config_file = f.name + + try: + test_argv = [ + "script.py", + "--config", + config_file, + "debug.verbose=True", + "export.output_dir=/tmp", + ] + with patch.object(sys, "argv", test_argv): + with self.assertRaises(ValueError): + main() + finally: + os.unlink(config_file) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h index c3ed668a4be..4c2efc91203 100644 --- a/extension/llm/runner/irunner.h +++ b/extension/llm/runner/irunner.h @@ -49,6 +49,10 @@ struct GenerationConfig { // Temperature for sampling (higher = more random) float temperature = 0.8f; + // Number of eos and bos to add to the prompt + int32_t num_bos = 0; + int32_t num_eos = 0; + /** * Resolve the maximum number of new tokens to generate based on constraints. * @@ -121,6 +125,23 @@ class ET_EXPERIMENTAL IRunner { std::function token_callback, std::function stats_callback) = 0; + /** + * Generate text based on the provided prompt and generation config, from a + * given position in KV cache. + * + * @param prompt The input prompt to generate from + * @param start_pos The starting position in KV cache of the input + * @param config Generation configuration parameters + * @param token_callback Callback function called for each generated token + * @param stats_callback Callback function for generation statistics + * @return Error::Ok if successful, an error otherwise + */ + virtual runtime::Error generate_from_pos( + const std::string& prompt, + int64_t start_pos, + const GenerationConfig& config, + std::function token_callback, + std::function stats_callback) = 0; /** * Stop the generation process. */ diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 2e8231748ed..244515112ac 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -103,7 +103,7 @@ def define_common_targets(): ":text_token_generator" + aten_suffix, "//pytorch/tokenizers:hf_tokenizer", "//pytorch/tokenizers:llama2c_tokenizer", - # "//pytorch/tokenizers:sentencepiece", # TODO(larryliu0820) Make sure this compiles in xplat. + "//pytorch/tokenizers:sentencepiece", "//pytorch/tokenizers:tiktoken", ], ) diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index ac46f0021fb..15b4d005f9d 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -17,13 +17,10 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) -set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp) +set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp + test_text_prefiller.cpp +) et_cxx_test( - test_runner - SOURCES - ${_test_srcs} - EXTRA_LIBS - executorch - extension_llm_runner + test_runner SOURCES ${_test_srcs} EXTRA_LIBS executorch extension_llm_runner ) diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index a5c8be7b6de..8bc3d4cc100 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -27,3 +27,12 @@ def define_common_targets(): "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], ) + + runtime.cxx_test( + name = "test_text_prefiller", + srcs = ["test_text_prefiller.cpp"], + deps = [ + "//executorch/extension/llm/runner:runner_lib", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + ], + ) diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index a9c2c680609..02f04a69b38 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -322,3 +322,58 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) { // Verify is_loaded returns true EXPECT_TRUE(runner.is_loaded()); } + +// Test that generate_from_pos() errors out when max_new_tokens is negative +TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + EXPECT_CALL(*tokenizer, encode(_, _, _)) + .WillOnce(Return(::tokenizers::Result>( + std::vector{1, 2, 3}))); + + // Set up expectations for load methods + EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + + std::unique_ptr stats = + std::make_unique(); + // Create a real TextTokenGenerator + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), stats.get()); + + // Create a Runner with our mocked components + TextLLMRunner runner( + { + {"enable_dynamic_shape", false}, + {"get_max_seq_len", 10}, + {"get_max_context_len", 10}, + {"use_kv_cache", true}, + }, + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::make_unique(), + std::move(text_decoder_runner), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(text_token_generator), + std::move(stats)); + + // Load + runner.load(); + + // Set up the generation config with a negative max_new_tokens value + GenerationConfig config; + config.max_new_tokens = 5; + config.echo = false; + + // num_prompt_tokens = 3 + // max_context_len = 10 + // start_pos = 8, this should fail because 10 - 8 > 3, even though + // config.max_new_tokens = 5 > 3, it's still a failure. + Error err = runner.generate_from_pos("test prompt", 8, config); + + // Verify that an InvalidArgument error is returned + EXPECT_EQ(err, Error::InvalidArgument); +} diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp new file mode 100644 index 00000000000..b786fc71978 --- /dev/null +++ b/extension/llm/runner/test/test_text_prefiller.cpp @@ -0,0 +1,306 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated + */ + +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using executorch::extension::llm::TextDecoderRunner; +using executorch::extension::llm::TextPrefiller; +using executorch::runtime::Error; +using executorch::runtime::Result; +using executorch::runtime::testing::TensorFactory; + +// Mock class for TextDecoderRunner +class MockTextDecoderRunner : public TextDecoderRunner { + public: + MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {} + MOCK_METHOD( + Result, + step, + (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&), + ()); + MOCK_METHOD(bool, is_method_loaded, (), ()); + MOCK_METHOD(Result, prefill, (std::vector&, int64_t), ()); + MOCK_METHOD(::executorch::runtime::Error, load, (), ()); +}; + +// Test fixture for TextPrefiller tests +class TextPrefillerTest : public Test { + protected: + void SetUp() override { + executorch::runtime::runtime_init(); + // Set up default behavior for the text decoder runner + ON_CALL(text_decoder_runner_, is_method_loaded()) + .WillByDefault(Return(true)); + ON_CALL(text_decoder_runner_, step) + .WillByDefault([&](executorch::extension::TensorPtr&, + executorch::extension::TensorPtr&) { + return Result(tensor); + }); + } + + // Helper method to create a TextPrefiller with specific parameters + std::unique_ptr createTextPrefiller( + int64_t max_seq_len, + bool use_kv_cache = true, + bool enable_parallel_prefill = false) { + return std::make_unique( + &text_decoder_runner_, + use_kv_cache, + enable_parallel_prefill, + max_seq_len); + } + + // Create a mock TextPrefiller that allows us to mock prefill_chunk calls + class MockTextPrefiller : public TextPrefiller { + public: + MockTextPrefiller( + TextDecoderRunner* text_decoder_runner, + bool use_kv_cache, + bool enable_parallel_prefill, + int64_t max_seq_len) + : TextPrefiller( + text_decoder_runner, + use_kv_cache, + enable_parallel_prefill, + max_seq_len) {} + + MOCK_METHOD( + ::executorch::runtime::Result, + prefill_chunk, + (std::vector&, int64_t&), + ()); + }; + + // Create a mock TextPrefiller + std::unique_ptr createMockTextPrefiller( + int64_t max_seq_len, + bool use_kv_cache = true, + bool enable_parallel_prefill = false) { + return std::make_unique( + &text_decoder_runner_, + use_kv_cache, + enable_parallel_prefill, + max_seq_len); + } + + MockTextDecoderRunner text_decoder_runner_; + std::vector return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f}; + TensorFactory tf; + executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_); +}; + +// Test that prefill() calls prefill_chunk() once when prompt tokens <= +// max_seq_len +TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) { + // Create a spy TextPrefiller with max_seq_len = 10 + auto prefiller = createMockTextPrefiller(10); + + // Create prompt tokens with size <= max_seq_len + std::vector prompt_tokens = {1, 2, 3, 4, 5}; + int64_t start_pos = 0; + + // Expect prefill_chunk to be called exactly once with the entire prompt + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .Times(1) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + // Verify the tokens passed to prefill_chunk + EXPECT_EQ(tokens.size(), prompt_tokens.size()); + for (size_t i = 0; i < tokens.size(); i++) { + EXPECT_EQ(tokens[i], prompt_tokens[i]); + } + // Verify the position + EXPECT_EQ(pos, start_pos); + return Result(42); + }); + + // Call prefill + auto result = prefiller->prefill(prompt_tokens, start_pos); + + // Verify the result + EXPECT_EQ(result.error(), Error::Ok); + EXPECT_EQ(result.get(), 42); +} + +// Test that prefill() calls prefill_chunk() multiple times when prompt tokens > +// max_seq_len +TEST_F( + TextPrefillerTest, + PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) { + // Create a spy TextPrefiller with max_seq_len = 3 + const int64_t max_seq_len = 3; + auto prefiller = createMockTextPrefiller(max_seq_len); + + // Create prompt tokens with size > max_seq_len + std::vector prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8}; + int64_t start_pos = 0; + + // Set up expectations for prefill_chunk calls + { + InSequence seq; // Ensure calls happen in the expected order + + // First chunk: tokens [1, 2, 3] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 3); + EXPECT_EQ(tokens[0], 1); + EXPECT_EQ(tokens[1], 2); + EXPECT_EQ(tokens[2], 3); + EXPECT_EQ(pos, 0); + return Result(10); + }); + + // Second chunk: tokens [4, 5, 6] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 3); + EXPECT_EQ(tokens[0], 4); + EXPECT_EQ(tokens[1], 5); + EXPECT_EQ(tokens[2], 6); + EXPECT_EQ(pos, 3); + return Result(20); + }); + + // Third chunk: tokens [7, 8] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 2); + EXPECT_EQ(tokens[0], 7); + EXPECT_EQ(tokens[1], 8); + EXPECT_EQ(pos, 6); + return Result(30); + }); + } + + // Call prefill + auto result = prefiller->prefill(prompt_tokens, start_pos); + + // Verify the result + EXPECT_EQ(result.error(), Error::Ok); + EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk + + // Verify that start_pos has been updated correctly + EXPECT_EQ(start_pos, prompt_tokens.size()); +} + +// Test that prefill() handles edge cases correctly +TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) { + // Create a spy TextPrefiller with max_seq_len = 1 + const int64_t max_seq_len = 1; + auto prefiller = createMockTextPrefiller(max_seq_len); + + // Create prompt tokens with size > max_seq_len + std::vector prompt_tokens = {1, 2, 3}; + int64_t start_pos = 5; // Non-zero starting position + + // Set up expectations for prefill_chunk calls + { + InSequence seq; + + // First chunk: token [1] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 1); + EXPECT_EQ(tokens[0], 1); + EXPECT_EQ(pos, 5); + return Result(10); + }); + + // Second chunk: token [2] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 1); + EXPECT_EQ(tokens[0], 2); + EXPECT_EQ(pos, 6); + return Result(20); + }); + + // Third chunk: token [3] + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + EXPECT_EQ(tokens.size(), 1); + EXPECT_EQ(tokens[0], 3); + EXPECT_EQ(pos, 7); + return Result(30); + }); + } + + // Call prefill + auto result = prefiller->prefill(prompt_tokens, start_pos); + + // Verify the result + EXPECT_EQ(result.error(), Error::Ok); + EXPECT_EQ(result.get(), 30); + + // Verify that start_pos has been updated correctly + EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens) +} + +// Test that prefill() handles errors from prefill_chunk correctly +TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) { + // Create a spy TextPrefiller with max_seq_len = 3 + const int64_t max_seq_len = 3; + auto prefiller = createMockTextPrefiller(max_seq_len); + + // Create prompt tokens with size > max_seq_len + std::vector prompt_tokens = {1, 2, 3, 4, 5}; + int64_t start_pos = 0; + + // Set up expectations for prefill_chunk calls + { + InSequence seq; + + // First chunk: tokens [1, 2, 3] - succeeds + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + return Result(10); + }); + + // Second chunk: tokens [4, 5] - fails + EXPECT_CALL(*prefiller, prefill_chunk(_, _)) + .WillOnce([&](std::vector& tokens, int64_t& pos) { + return Result(Error::InvalidArgument); + }); + } + + // Call prefill + auto result = prefiller->prefill(prompt_tokens, start_pos); + + // Verify that the error is propagated + EXPECT_EQ(result.error(), Error::InvalidArgument); +} + +// Test that prefill_chunk() works correctly with parallel prefill enabled +TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) { + // Create a TextPrefiller with parallel prefill enabled + auto prefiller = createTextPrefiller(10, true, true); + + // Set up expectations for the text decoder runner + EXPECT_CALL(text_decoder_runner_, step(_, _)) + .Times(1) + .WillOnce(Return(Result(tensor))); + + // Create prompt tokens + std::vector prompt_tokens = {1, 2, 3}; + int64_t start_pos = 0; + + // Call prefill + auto result = prefiller->prefill(prompt_tokens, start_pos); + + // Verify the result + EXPECT_EQ(result.error(), Error::Ok); + + // Verify that start_pos has been updated correctly + EXPECT_EQ(start_pos, prompt_tokens.size()); +} diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 9fa20d2646e..6a0cfd45044 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace executorch::extension::llm { @@ -73,8 +74,9 @@ Error TextLLMRunner::load() { ET_LOG(Info, format, __VA_ARGS__); \ } -Error TextLLMRunner::generate( +Error TextLLMRunner::generate_from_pos( const std::string& prompt, + int64_t start_pos, const GenerationConfig& config, std::function token_callback, std::function stats_callback) { @@ -115,8 +117,8 @@ Error TextLLMRunner::generate( ::tokenizers::Result> encode_res = tokenizer_->encode( prompt, - /* bos */ 0, - /* eos */ 0); + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); ET_CHECK_TK_OK_OR_RETURN_ERROR( encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); @@ -125,20 +127,38 @@ Error TextLLMRunner::generate( std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token"); - ET_CHECK_MSG( - num_prompt_tokens < metadata_.at(kMaxContextLen), - "num_prompt_tokens %d >= max_seq_len_ %" PRId64 + // Reduce max_context_len by start_pos + int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos; + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens >= 1, + InvalidArgument, + "Expected at least 1 prompt token"); + ET_CHECK_OR_RETURN_ERROR( + num_prompt_tokens < max_context_len, + InvalidArgument, + "num_prompt_tokens %d >= max_context_len %" PRId64 ", Max seq length exceeded - please increase max seq len value in your export script", num_prompt_tokens, - metadata_.at(kMaxContextLen)); - - // Determine max_new_tokens using the GenerationConfig's resolve method - int max_new_tokens = config.resolve_max_new_tokens( - metadata_.at(kMaxContextLen), num_prompt_tokens); - - ET_LOG(Info, "Max new tokens resolved: %d", max_new_tokens); - + max_context_len); + + // Determine max_new_tokens using the GenerationConfig's resolve method, + // then subtract start_pos for max_new_tokens. + int max_new_tokens = + config.resolve_max_new_tokens(max_context_len, num_prompt_tokens); + + ET_LOG( + Info, + "Max new tokens resolved: %d, given start_pos %" PRId64 + ", num_prompt_tokens %zu, max_context_len %" PRId64, + max_new_tokens, + start_pos, + prompt_tokens.size(), + max_context_len); + ET_CHECK_OR_RETURN_ERROR( + max_new_tokens > 0, + InvalidArgument, + "Max new tokens %d is less than or equal to 0", + max_new_tokens); // Prefill first // Here feed all tokens to the model and get the next predicted token // after the prompt. After that we will enter generate loop. @@ -147,7 +167,7 @@ Error TextLLMRunner::generate( if (config.echo) { wrapped_callback(prompt); } - int64_t pos = 0; + int64_t pos = start_pos; auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); @@ -201,6 +221,13 @@ Error TextLLMRunner::generate( return Error::Ok; } +Error TextLLMRunner::generate( + const std::string& prompt, + const GenerationConfig& config, + std::function token_callback, + std::function stats_callback) { + return generate_from_pos(prompt, 0, config, token_callback, stats_callback); +} Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { // Create a GenerationConfig for warmup @@ -252,6 +279,12 @@ std::unique_ptr load_tokenizer( return tiktoken_tokenizer; } + auto sp_tokenizer = std::make_unique<::tokenizers::SPTokenizer>(); + if (sp_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) { + ET_LOG(Info, "Loaded Sentencepiece tokenizer"); + return sp_tokenizer; + } + auto bpe_tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>(); if (bpe_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) { ET_LOG(Info, "Loaded BPE tokenizer"); diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 715688ba82c..600d21a8801 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -78,8 +78,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { * @brief Generates text based on the provided prompt * * This method performs text generation using the loaded model. It processes - * the input prompt, runs the model in prefill and decode phases, and returns - * generated text through callbacks. + * the input prompt, runs the model in prefill and decode phases until max + * tokens to generate is reached or eos token is generated, then returns + * generated text and perf stats through callbacks. * * @param prompt The input text to generate from * @param config Configuration parameters for text generation (e.g., @@ -94,6 +95,31 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { const GenerationConfig& config, std::function token_callback = {}, std::function stats_callback = {}) override; + + /** + * @brief Generates text based on the provided prompt and start position + * + * This method performs text generation using the loaded model. It processes + * the input prompt, runs the model in prefill and decode phases using the + * start position until max tokens to generate is reached or eos token is + * generated, then returns generated text and perf stats through callbacks. + * + * @param prompt The input text to generate from + * @param start_pos The starting position in KV cache of the input + * @param config Configuration parameters for text generation (e.g., + * max_new_tokens, temperature) + * @param token_callback Function called for each generated token with the + * decoded text + * @param stats_callback Function called with performance statistics + * @return ::executorch::runtime::Error Success or error status + */ + ::executorch::runtime::Error generate_from_pos( + const std::string& prompt, + int64_t start_pos, + const GenerationConfig& config, + std::function token_callback = {}, + std::function stats_callback = {}) override; + /** * @brief Warms up the model with a sample prompt * diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 19c260f5be6..64f3fee167b 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -24,8 +24,7 @@ TextPrefiller::TextPrefiller( : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache), enable_parallel_prefill_(enable_parallel_prefill), - max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) { -} // -1 because for some reason tracing results in this upperbound + max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {} ::executorch::runtime::Result TextPrefiller::prefill( std::vector& prompt_tokens, @@ -56,21 +55,22 @@ ::executorch::runtime::Result TextPrefiller::prefill( prompt_tokens_to_process.begin()); // Process this chunk - auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos); + auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos); ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error()); cur_token = chunk_result.get(); + start_pos += num_tokens_to_prefill_with; num_tokens_to_process += num_tokens_to_prefill_with; } return cur_token; } else { // If prompt tokens don't exceed max_seq_len_, process them directly - return prefillChunk(prompt_tokens, start_pos); + return prefill_chunk(prompt_tokens, start_pos); } } -::executorch::runtime::Result TextPrefiller::prefillChunk( +::executorch::runtime::Result TextPrefiller::prefill_chunk( std::vector& prompt_tokens, int64_t& start_pos) { // enable_parallel_prefill_ maybe set even when not using kv cache diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index 49b2c867167..ce12506a05c 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -45,7 +45,7 @@ class ET_EXPERIMENTAL TextPrefiller { * Module. * @return The next token of the LLM Module after prefilling this chunk. */ - ::executorch::runtime::Result prefillChunk( + virtual ::executorch::runtime::Result prefill_chunk( std::vector& prompt_tokens, int64_t& start_pos); diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers index fc320288580..ffd2973e887 160000 --- a/extension/llm/tokenizers +++ b/extension/llm/tokenizers @@ -1 +1 @@ -Subproject commit fc32028858020c4fcafe37aaaeaf5d1b480336a2 +Subproject commit ffd2973e8879f64c78f01a3f4aa0f77bdc5a1abe diff --git a/install_requirements.py b/install_requirements.py index 7d923672009..66768426a99 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -75,6 +75,15 @@ def python_is_compatible(): def install_requirements(use_pytorch_nightly): + # Skip pip install on Intel macOS if using nightly. + if use_pytorch_nightly and is_intel_mac_os(): + print( + "ERROR: Prebuilt PyTorch wheels are no longer available for Intel-based macOS.\n" + "Please build from source by following https://docs.pytorch.org/executorch/main/using-executorch-building-from-source.html", + file=sys.stderr, + ) + sys.exit(1) + # pip packages needed by exir. TORCH_PACKAGE = [ # Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note @@ -163,6 +172,17 @@ def install_optional_example_requirements(use_pytorch_nightly): ) +# Prebuilt binaries for Intel-based macOS are no longer available on PyPI; users must compile from source. +# PyTorch stopped building macOS x86_64 binaries since version 2.3.0 (January 2024). +def is_intel_mac_os(): + # Returns True if running on Intel macOS. + return platform.system().lower() == "darwin" and platform.machine().lower() in ( + "x86", + "x86_64", + "i386", + ) + + def main(args): parser = argparse.ArgumentParser() parser.add_argument( diff --git a/kernels/optimized/lib_defs.bzl b/kernels/optimized/lib_defs.bzl index 7c8c1119ec7..b1dca5dff84 100644 --- a/kernels/optimized/lib_defs.bzl +++ b/kernels/optimized/lib_defs.bzl @@ -111,10 +111,6 @@ def get_preprocessor_flags(): return preprocessor_flags -# Currently, having a dependency on fbsource//third-party/sleef:sleef may cause -# duplicate symbol errors when linking fbcode targets in opt mode that also -# depend on ATen. This is because ATen accesses sleef via the third-party folder -# in caffe2 (caffe2/third-party//sleef:sleef). # TODO(ssjia): Enable -DCPU_CAPABILITY_AVX2 in fbcode, which requires sleef. def define_libs(is_fbcode=False): runtime.cxx_library( diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index c1f2770d3d6..876099598dc 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out( static_cast(scale)); \ } \ } break; -#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { @@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out( } \ out_data_ptr[current_ix] = \ static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ + input_data_ptr[current_ix] - \ + static_cast(zero_point)) * \ _scale; \ } \ }, \ @@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out( apply_over_dim_list( \ [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ + (input_data_ptr[in_ix] - static_cast(_zero_point)) * \ + _scale); \ }, \ input, \ optional_dim_list, \ channel_ix); \ } \ break; -#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ +#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 4665c3d665b..d0b7c882f8e 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index bbda1590a10..4a0c195e3ab 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype(); } +/// Test all supported output dtypes for dequantization +template +void test_output_dtype() { + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 100); + double scale = 0.5; + int64_t zero_point = 30; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (100 - 30) * 0.5 = 35 + Tensor expected = tfo.full({3, 5}, 35); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(OUT_DTYPE), + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, AllOutputDtypesSupported) { + et_pal_init(); + test_output_dtype(); + test_output_dtype(); + test_output_dtype(); +} + +TEST(OpDequantizeOutTest, HalfOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // (10 - 100000) * 0.5 = -49995 + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Half), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, DoubleOutput) { + et_pal_init(); + TensorFactory tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional(ScalarType::Double), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpDequantizeOutTest, NonWholeNumbers) { et_pal_init(); TensorFactory tf; diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 704d8d06c5c..5cd17223d80 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -49,6 +49,32 @@ void test_dtype() { EXPECT_TENSOR_EQ(out, expected); } +template +void test_input_dtype() { + TensorFactory tf_input; + + Tensor input = tf_input.full({3, 5}, 4); + double scale = 0.5; + int64_t zero_point = 108; + int64_t quant_min = 0; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({3, 5}); + // 4 / 0.5 + 108 = 116 + Tensor expected = tfo.full({3, 5}, 116); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, AllInputDtypesSupported) { + test_input_dtype(); + test_input_dtype(); + test_input_dtype(); +} + TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); test_dtype(); @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype(); } +TEST(OpQuantizeOutTest, DoubleInputTest) { + TensorFactory tf_double; + + // Test with a more complex value that might have precision differences + Tensor input = tf_double.full({2, 3}, 3.14159265359); + double scale = 0.01; + int64_t zero_point = -100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 3.14159265359 / 0.01 - 100 = 214.159265359 + Tensor expected = tfo.full({2, 3}, 214); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, HalfInputTest) { + TensorFactory tf_half; + + Tensor input = tf_half.full({2, 3}, 2.5); + double scale = 0.5; + int64_t zero_point = 10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 3}); + // 2.5 / 0.5 + 10 = 15 + Tensor expected = tfo.full({2, 3}, 15); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpQuantizeOutTest, TensorArgOverload) { TensorFactory tf_float; TensorFactory tf_double; diff --git a/requirements-dev.txt b/requirements-dev.txt index a4ed212fb65..07c63101eb8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,5 @@ wheel # For building the pip package archive. zstd # Imported by resolve_buck.py. lintrunner==0.12.7 lintrunner-adapters==0.12.4 +hydra-core>=1.3.0 +omegaconf>=2.3.0 diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 6f81146e925..d81b3ad4d0f 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT, float, Float) \ _(ANOTHER_INPUT, double, Double) +#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, ::executorch::aten::Half, Half) + #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) diff --git a/runtime/core/portable_type/c10/c10/targets.bzl b/runtime/core/portable_type/c10/c10/targets.bzl index 827a63d2cef..cb41bd0bb8e 100644 --- a/runtime/core/portable_type/c10/c10/targets.bzl +++ b/runtime/core/portable_type/c10/c10/targets.bzl @@ -1,10 +1,47 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime", "is_arvr_mode") -def get_sleef_preprocessor_flags(): +def get_preprocessor_flags(is_fbcode): + flags = ["-DSTANDALONE_TORCH_HEADER"] if runtime.is_oss: - return [] - return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"] + return flags + arm64_flags = [ + "-DCPU_CAPABILITY_DEFAULT", + ] + if is_fbcode: + # TODO: enable Sleef in xplat? + arm64_flags = arm64_flags + ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"] + x86_avx2_flags = [ + "-DCPU_CAPABILITY_AVX2", + "-DHAVE_AVX2_CPU_DEFINITION", + ] + default_flags = [ + "-DCPU_CAPABILITY_DEFAULT", + ] + fbcode_flags = select({ + "ovr_config//cpu:x86_64": x86_avx2_flags, + "ovr_config//cpu:arm64": arm64_flags, + "DEFAULT": default_flags, + }) + non_fbcode_flags = select({ + "ovr_config//cpu/x86:avx2": x86_avx2_flags, + "ovr_config//cpu:arm64": arm64_flags, + "DEFAULT": default_flags, + }) + return flags + ["-DET_USE_PYTORCH_HEADERS"] + (fbcode_flags if is_fbcode else non_fbcode_flags) + +def get_sleef_deps(): + if runtime.is_oss: + return [] + return select({ + "DEFAULT": [], + "ovr_config//cpu:x86_64": [ + "fbsource//third-party/sleef:sleef", + ], + "ovr_config//cpu:arm64": [ + "fbsource//third-party/sleef:sleef", + ], + }) def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. @@ -12,84 +49,83 @@ def define_common_targets(): The directory containing this targets.bzl file should also contain both TARGETS and BUCK files that call this function. """ - runtime.cxx_library( - name = "c10", - header_namespace = "c10", - exported_headers = [ - "macros/Export.h", - "macros/Macros.h", - "util/BFloat16.h", - "util/BFloat16-inl.h", - "util/BFloat16-math.h", - "util/Half.h", - "util/Half-inl.h", - "util/TypeSafeSignMath.h", - "util/bit_cast.h", - "util/complex.h", - "util/complex_math.h", - "util/complex_utils.h", - "util/floating_point_utils.h", - "util/irange.h", - ], - exported_preprocessor_flags = [ - "-DC10_USING_CUSTOM_GENERATED_MACROS", - ] + ([] if runtime.is_oss else [ - "-DC10_USE_GLOG", - "-DC10_USE_MINIMAL_GLOG", - ]), - visibility = [ - "//executorch/...", - "@EXECUTORCH_CLIENTS", - ], - deps = select({ - "DEFAULT": [], - # Half-inl.h depends on vec_half.h from ATen, but only when building for x86. - "ovr_config//cpu:x86_64": [ - ":aten_headers_for_executorch", - ], - }), - ) runtime.cxx_library( name = "aten_headers_for_executorch", srcs = [], visibility = ["//executorch/kernels/optimized/...", "@EXECUTORCH_CLIENTS"], + # select() on ovr_config//runtime:fbcode does not work + # properly in all cases. I have seen + # //xplat/executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch + # pass such a select in (at least) arvr mode. Going back to + # fbcode_exported_deps accordingly. exported_deps = select({ "DEFAULT": [], "ovr_config//cpu:arm64": [ "fbsource//third-party/sleef:sleef", ] if not runtime.is_oss else [], - # fbsource//third-party/sleef:sleef currently fails to - # link with missing symbols, hence the fbcode-specific dep below. }), + xplat_exported_deps = [ + "//xplat/caffe2:aten_header", + "//xplat/caffe2/c10:c10_headers", + ("//xplat/caffe2:ovrsource_aten_Config.h" + if is_arvr_mode() else "//xplat/caffe2:generated_aten_config_header"), + ], # + get_sleef_deps(), # TODO: enable Sleef in xplat? fbcode_exported_deps = ([ "//caffe2:aten-headers-cpu", "//caffe2:generated-config-header", "//caffe2/c10:c10_headers", - ] + select({ - "DEFAULT": [], - "ovr_config//cpu:x86_64": [ - "third-party//sleef:sleef", - ] - })) if not runtime.is_oss else [], - fbcode_exported_preprocessor_flags = [ - # We don't -DCPU_CAPABILITY=AVX2 because that trips - # -Wmacro-redefined, and we only care about getting - # reasonable vectorization and Sleef support. - "-DCPU_CAPABILITY_AVX2", - "-DET_USE_PYTORCH_HEADERS", - "-DHAVE_AVX2_CPU_DEFINITION", - "-DSTANDALONE_TORCH_HEADER", - ] + get_sleef_preprocessor_flags(), - xplat_exported_deps = [ - "//xplat/caffe2:aten_header", - "//xplat/caffe2/c10:c10_headers", - ] + ["//xplat/caffe2:ovrsource_aten_Config.h" if is_arvr_mode() else "//xplat/caffe2:generated_aten_config_header",], - exported_preprocessor_flags = select({ - # Intentionally punting on non-fbcode x86 sleef support - # for now because of fbsource//third-party/sleef:sleef - # linker failure. - "ovr_config//cpu:arm64": get_sleef_preprocessor_flags(), - "DEFAULT": [], - }) + ["-DSTANDALONE_TORCH_HEADER"] + ([] if runtime.is_oss else ["-DET_USE_PYTORCH_HEADERS"]), + ] + get_sleef_deps()) if not runtime.is_oss else [], + exported_preprocessor_flags = get_preprocessor_flags(is_fbcode=False) + + ([] if runtime.is_oss else ["-DET_USE_PYTORCH_HEADERS"]), + fbcode_exported_preprocessor_flags = get_preprocessor_flags(is_fbcode=True) + + ([] if runtime.is_oss else ["-DET_USE_PYTORCH_HEADERS"]), ) + + if runtime.is_oss: + runtime.cxx_library( + name = "c10", + header_namespace = "c10", + exported_headers = [ + "macros/Export.h", + "macros/Macros.h", + "util/BFloat16.h", + "util/BFloat16-inl.h", + "util/BFloat16-math.h", + "util/Half.h", + "util/Half-inl.h", + "util/TypeSafeSignMath.h", + "util/bit_cast.h", + "util/complex.h", + "util/complex_math.h", + "util/complex_utils.h", + "util/floating_point_utils.h", + "util/irange.h", + ], + exported_preprocessor_flags = [ + "-DC10_USING_CUSTOM_GENERATED_MACROS", + ] + ([] if runtime.is_oss else [ + "-DC10_USE_GLOG", + "-DC10_USE_MINIMAL_GLOG", + ]), + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = select({ + "DEFAULT": [], + # Half-inl.h depends on vec_half.h from ATen, but only when building for x86. + "ovr_config//cpu:x86_64": [ + ":aten_headers_for_executorch", + ], + }), + ) + else: + runtime.cxx_library( + name = "c10", + exported_deps = [":aten_headers_for_executorch"], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + ) diff --git a/runtime/core/span.h b/runtime/core/span.h index 1bcde396ccd..2202204571a 100644 --- a/runtime/core/span.h +++ b/runtime/core/span.h @@ -55,6 +55,10 @@ class Span final { template /* implicit */ constexpr Span(T (&Arr)[N]) : data_(Arr), length_(N) {} + /// Construct a Span from a single element reference. + /* implicit */ constexpr Span(T& single_element) + : data_(&single_element), length_(1) {} + /// @returns a pointer to the start of the underlying element buffer. iterator begin() const noexcept { return data_; diff --git a/runtime/core/test/span_test.cpp b/runtime/core/test/span_test.cpp index c2d65baf8e7..51063b8be0f 100644 --- a/runtime/core/test/span_test.cpp +++ b/runtime/core/test/span_test.cpp @@ -61,3 +61,19 @@ TEST(SpanTest, TriviallyCopyable) { EXPECT_EQ(span.size(), span_copy.size()); EXPECT_TRUE(std::is_trivially_copyable>::value); } + +TEST(SpanTest, SingleElementConstructor) { + int64_t single_value = 42; + Span span = single_value; + + EXPECT_EQ(span.size(), 1); + EXPECT_EQ(span.data(), &single_value); + EXPECT_EQ(span[0], 42); + EXPECT_EQ(*span.begin(), 42); + EXPECT_EQ(span.end(), span.begin() + 1); + + // Test that modifying through span affects original value + span[0] = 100; + EXPECT_EQ(single_value, 100); + EXPECT_EQ(span[0], 100); +} diff --git a/runtime/platform/compiler.h b/runtime/platform/compiler.h index 864b76e2050..f8588930e15 100644 --- a/runtime/platform/compiler.h +++ b/runtime/platform/compiler.h @@ -173,7 +173,7 @@ using ssize_t = ptrdiff_t; #ifdef __EXCEPTIONS #define ET_HAS_EXCEPTIONS 1 -#elif defined(_HAS_EXCEPTIONS) && _HAS_EXCEPTIONS +#elif defined(_MSC_VER) && defined(_HAS_EXCEPTIONS) && _HAS_EXCEPTIONS #define ET_HAS_EXCEPTIONS 1 #else #define ET_HAS_EXCEPTIONS 0 diff --git a/shim_et/xplat/executorch/kernels/optimized/lib_defs.bzl b/shim_et/xplat/executorch/kernels/optimized/lib_defs.bzl index a940a7114cf..d2feda36407 100644 --- a/shim_et/xplat/executorch/kernels/optimized/lib_defs.bzl +++ b/shim_et/xplat/executorch/kernels/optimized/lib_defs.bzl @@ -69,15 +69,14 @@ def get_vec_cxx_preprocessor_flags(): return preprocessor_flags def get_vec_fbcode_preprocessor_flags(): - preprocessor_flags = [ - "-DCPU_CAPABILITY_AVX2", - ] + preprocessor_flags = select({ + "ovr_config//cpu/x86:avx2": [ + "-DCPU_CAPABILITY_AVX2", + ], + "DEFAULT": [], + }) return preprocessor_flags -# Currently, having a dependency on fbsource//third-party/sleef:sleef may cause -# duplicate symbol errors when linking fbcode targets in opt mode that also -# depend on ATen. This is because ATen accesses sleef via the third-party folder -# in caffe2 (caffe2/third-party//sleef:sleef). # TODO(ssjia): Enable -DCPU_CAPABILITY_AVX2 in fbcode, which requires sleef. def define_libs(): runtime.cxx_library( diff --git a/tools/cmake/Codegen.cmake b/tools/cmake/Codegen.cmake index 2fc4c2675a5..93331c7ed89 100644 --- a/tools/cmake/Codegen.cmake +++ b/tools/cmake/Codegen.cmake @@ -22,15 +22,20 @@ function(gen_selected_ops) message(STATUS " INCLUDE_ALL_OPS: ${GEN_INCLUDE_ALL_OPS}") message(STATUS " OPS_FROM_MODEL: ${GEN_OPS_FROM_MODEL}") message(STATUS " DTYPE_SELECTIVE_BUILD: ${GEN_DTYPE_SELECTIVE_BUILD}") + + set(_out_dir ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}) + if(GEN_DTYPE_SELECTIVE_BUILD) - message(STATUS " DTYPE_SELECTIVE_BUILD is still WIP and may not be fully functional") + if(NOT GEN_OPS_FROM_MODEL) + message(FATAL_ERROR " DTYPE_SELECTIVE_BUILD is only support with model API, please pass in a model") + endif() endif() set(_oplist_yaml - ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}/selected_operators.yaml + ${_out_dir}/selected_operators.yaml ) - file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}) + file(MAKE_DIRECTORY ${_out_dir}) file(GLOB_RECURSE _codegen_tools_srcs "${EXECUTORCH_ROOT}/codegen/tools/*.py") @@ -64,18 +69,18 @@ function(gen_selected_ops) if(GEN_DTYPE_SELECTIVE_BUILD) set(_opvariant_h - ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}/selected_op_variants.h + ${_out_dir}/selected_op_variants.h ) set(_gen_opvariant_command "${PYTHON_EXECUTABLE}" -m codegen.tools.gen_selected_op_variants --yaml-file=${_oplist_yaml} - --output-dir=${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}/ + --output-dir=${_out_dir}/ ) message("Command - ${_gen_opvariant_command}") add_custom_command( - COMMENT "Generating selected_op_variants.h for ${GEN_LIB_NAME}" + COMMENT "Generating ${_opvariant_h} for ${GEN_LIB_NAME}" OUTPUT ${_opvariant_h} COMMAND ${_gen_opvariant_command} - DEPENDS ${_oplist_yaml} ${_codegen_tools_srcs} + DEPENDS ${_oplist_yaml} ${GEN_OPS_SCHEMA_YAML} ${_codegen_tools_srcs} WORKING_DIRECTORY ${EXECUTORCH_ROOT} ) endif() @@ -88,7 +93,7 @@ endfunction() # functions_yaml CUSTOM_OPS_YAML custom_ops_yaml ) function(generate_bindings_for_kernels) set(options ADD_EXCEPTION_BOUNDARY) - set(arg_names LIB_NAME FUNCTIONS_YAML CUSTOM_OPS_YAML) + set(arg_names LIB_NAME FUNCTIONS_YAML CUSTOM_OPS_YAML DTYPE_SELECTIVE_BUILD) cmake_parse_arguments(GEN "${options}" "${arg_names}" "" ${ARGN}) message(STATUS "Generating kernel bindings:") @@ -96,6 +101,7 @@ function(generate_bindings_for_kernels) message(STATUS " FUNCTIONS_YAML: ${GEN_FUNCTIONS_YAML}") message(STATUS " CUSTOM_OPS_YAML: ${GEN_CUSTOM_OPS_YAML}") message(STATUS " ADD_EXCEPTION_BOUNDARY: ${GEN_ADD_EXCEPTION_BOUNDARY}") + message(STATUS " DTYPE_SELECTIVE_BUILD: ${GEN_DTYPE_SELECTIVE_BUILD}") # Command to generate selected_operators.yaml from custom_ops.yaml. file(GLOB_RECURSE _codegen_templates "${EXECUTORCH_ROOT}/codegen/templates/*") @@ -104,6 +110,13 @@ function(generate_bindings_for_kernels) # By default selective build output is selected_operators.yaml set(_oplist_yaml ${_out_dir}/selected_operators.yaml) + # If dtype selective build is enable, force header file to be preserved + if(GEN_DTYPE_SELECTIVE_BUILD) + set(_opvariant_h ${_out_dir}/selected_op_variants.h) + else() + set(_opvariant_h "") + endif() + # Command to codegen C++ wrappers to register custom ops to both PyTorch and # Executorch runtime. execute_process( @@ -148,8 +161,9 @@ function(generate_bindings_for_kernels) COMMENT "Generating code for kernel registration" OUTPUT ${_gen_command_sources} COMMAND ${_gen_command} - DEPENDS ${_oplist_yaml} ${GEN_CUSTOM_OPS_YAML} ${GEN_FUNCTIONS_YAML} - ${_codegen_templates} ${_torchgen_srcs} + DEPENDS ${_oplist_yaml} ${_opvariant_h} ${GEN_CUSTOM_OPS_YAML} + ${GEN_FUNCTIONS_YAML} ${_codegen_templates} + ${_torchgen_srcs} WORKING_DIRECTORY ${EXECUTORCH_ROOT} ) # Make generated file list available in parent scope @@ -191,29 +205,85 @@ endfunction() # Generate a runtime lib for registering operators in Executorch function(gen_operators_lib) - set(multi_arg_names LIB_NAME KERNEL_LIBS DEPS) + set(multi_arg_names LIB_NAME KERNEL_LIBS DEPS DTYPE_SELECTIVE_BUILD) cmake_parse_arguments(GEN "" "" "${multi_arg_names}" ${ARGN}) message(STATUS "Generating operator lib:") message(STATUS " LIB_NAME: ${GEN_LIB_NAME}") message(STATUS " KERNEL_LIBS: ${GEN_KERNEL_LIBS}") message(STATUS " DEPS: ${GEN_DEPS}") + message(STATUS " DTYPE_SELECTIVE_BUILD: ${GEN_DTYPE_SELECTIVE_BUILD}") set(_out_dir ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}) + if(GEN_DTYPE_SELECTIVE_BUILD) + set(_opvariant_h + ${_out_dir}/selected_op_variants.h + ) + endif() add_library(${GEN_LIB_NAME}) + + set(_srcs_list + ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp + ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h + ) + if(GEN_DTYPE_SELECTIVE_BUILD) + list(APPEND _srcs_list ${_opvariant_h}) + endif() target_sources( ${GEN_LIB_NAME} - PRIVATE ${_out_dir}/RegisterCodegenUnboxedKernelsEverything.cpp - ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h + PRIVATE ${_srcs_list} ) target_link_libraries(${GEN_LIB_NAME} PRIVATE ${GEN_DEPS}) + set(portable_kernels_check "portable_kernels") if(GEN_KERNEL_LIBS) - target_link_libraries(${GEN_LIB_NAME} PUBLIC ${GEN_KERNEL_LIBS}) + + set(_common_compile_options -Wno-deprecated-declarations -ffunction-sections -fdata-sections -Os) + + if(GEN_DTYPE_SELECTIVE_BUILD) + if("${portable_kernels_check}" IN_LIST GEN_KERNEL_LIBS) + list(REMOVE_ITEM GEN_KERNEL_LIBS ${portable_kernels_check}) + + # Build kernels_util_all_deps, since later selected_portable_kernels depends on it + list(TRANSFORM _kernels_util_all_deps__srcs PREPEND "${EXECUTORCH_ROOT}/") + add_library(selected_kernels_util_all_deps ${_kernels_util_all_deps__srcs}) + target_link_libraries(selected_kernels_util_all_deps PRIVATE executorch_core) + target_include_directories(selected_kernels_util_all_deps PUBLIC ${_common_include_directories}) + target_compile_definitions(selected_kernels_util_all_deps PUBLIC C10_USING_CUSTOM_GENERATED_MACROS) + target_compile_options(selected_kernels_util_all_deps PUBLIC ${_common_compile_options}) + + # Build selected_portable_kernels + list(TRANSFORM _portable_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/") + add_library(selected_portable_kernels ${_portable_kernels__srcs}) + target_link_libraries(selected_portable_kernels PRIVATE executorch_core selected_kernels_util_all_deps) + target_compile_options(selected_portable_kernels PUBLIC ${_common_compile_options}) + target_include_directories(selected_portable_kernels PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${GEN_LIB_NAME}/) + + # Make sure the header is generated before compiling the library + add_dependencies(selected_portable_kernels ${GEN_LIB_NAME}) + # Create a custom target for the header to ensure proper dependency tracking + add_custom_target(selected_portable_kernels_header DEPENDS ${_opvariant_h}) + add_dependencies(selected_portable_kernels selected_portable_kernels_header) + # Apply the compile definition for dtype selective build + target_compile_definitions(selected_portable_kernels PRIVATE EXECUTORCH_SELECTIVE_BUILD_DTYPE=1) + + target_link_libraries(${GEN_LIB_NAME} PUBLIC selected_portable_kernels) + else() + message(FATAL_ERROR "Currently dtype selective build is only supported for portable_kernels but {${GEN_KERNEL_LIBS}} were provided!") + endif() + endif() + + # After removing portable_kernels, test if there are other kernel libs provided + if(GEN_KERNEL_LIBS) + target_link_libraries(${GEN_LIB_NAME} PUBLIC ${GEN_KERNEL_LIBS}) + endif() endif() target_link_options_shared_lib(${GEN_LIB_NAME}) set(_generated_headers ${_out_dir}/Functions.h ${_out_dir}/NativeFunctions.h) + if(GEN_DTYPE_SELECTIVE_BUILD) + list(APPEND _generated_headers ${_opvariant_h}) + endif() set_target_properties( ${GEN_LIB_NAME} PROPERTIES PUBLIC_HEADER "${_generated_headers}" )