diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc
index 1776d492e..47fc60068 100644
--- a/bin/pytorch_inference/CSupportedOperations.cc
+++ b/bin/pytorch_inference/CSupportedOperations.cc
@@ -39,6 +39,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
 // elastic/test-elser-v2.
 // Additional ops from Elasticsearch integration test models
 // (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
+// Quantized operations from dynamically quantized variants of the above
+// models (torch.quantization.quantize_dynamic on nn.Linear layers).
 const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = {
     // aten operations (core tensor computations)
     "aten::Int"sv,
@@ -79,6 +81,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
     "aten::mean"sv,
     "aten::min"sv,
     "aten::mul"sv,
+    "aten::mul_"sv,
     "aten::ne"sv,
     "aten::neg"sv,
     "aten::new_ones"sv,
@@ -124,6 +127,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
     "prim::dtype"sv,
     "prim::max"sv,
     "prim::min"sv,
+    // quantized operations (dynamically quantized models, e.g. ELSER v2)
+    "quantized::linear_dynamic"sv,
 };
 }
 }
diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
index 164ead379..364d49f86 100644
--- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
+++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
@@ -267,6 +267,7 @@
   },
   "elastic-eis-elser-v2": {
     "model_id": "elastic/eis-elser-v2",
+    "quantized": false,
     "ops": [
       "aten::Int",
       "aten::ScalarImplicit",
@@ -303,6 +304,7 @@
   },
   "elastic-elser-v2": {
     "model_id": "elastic/elser-v2",
+    "quantized": false,
     "ops": [
       "aten::Int",
       "aten::ScalarImplicit",
@@ -337,6 +339,44 @@
       "prim::NumToTensor"
     ]
   },
+  "elastic-elser-v2-quantized": {
+    "model_id": "elastic/elser-v2",
+    "quantized": true,
+    "ops": [
+      "aten::Int",
+      "aten::ScalarImplicit",
+      "aten::__and__",
+      "aten::add",
+      "aten::arange",
+      "aten::contiguous",
+      "aten::dropout",
+      "aten::embedding",
+      "aten::expand",
+      "aten::gather",
+      "aten::ge",
+      "aten::gelu",
+      "aten::index",
+      "aten::layer_norm",
+      "aten::mul_",
+      "aten::new_ones",
+      "aten::reshape",
+      "aten::scaled_dot_product_attention",
+      "aten::select",
+      "aten::size",
+      "aten::slice",
+      "aten::tanh",
+      "aten::to",
+      "aten::transpose",
+      "aten::unsqueeze",
+      "aten::view",
+      "prim::Constant",
+      "prim::DictConstruct",
+      "prim::GetAttr",
+      "prim::ListConstruct",
+      "prim::NumToTensor",
+      "quantized::linear_dynamic"
+    ]
+  },
   "elastic-hugging-face-elser": {
     "model_id": "elastic/hugging-face-elser",
     "ops": [
diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py
index 676a7ef4b..562590c31 100644
--- a/dev-tools/extract_model_ops/extract_model_ops.py
+++ b/dev-tools/extract_model_ops/extract_model_ops.py
@@ -35,25 +35,25 @@
 import torch
 
-from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model
+from torchscript_utils import (
+    collect_inlined_ops,
+    load_and_trace_hf_model,
+    load_model_config,
+)
 
 
 SCRIPT_DIR = Path(__file__).resolve().parent
 DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json"
 
 
-def load_reference_models(config_path: Path) -> dict[str, str]:
-    """Load the architecture-to-model mapping from a JSON config file."""
-    with open(config_path) as f:
-        return json.load(f)
-
-
-def extract_ops_for_model(model_name: str) -> set[str] | None:
+def extract_ops_for_model(model_name: str,
+                          quantize: bool = False) -> set[str] | None:
     """Trace a HuggingFace model and return its TorchScript op set.
 
     Returns None if the model could not be loaded or traced.
     """
-    print(f"  Loading {model_name}...", file=sys.stderr)
-    traced = load_and_trace_hf_model(model_name)
+    label = f"{model_name} (quantized)" if quantize else model_name
+    print(f"  Loading {label}...", file=sys.stderr)
+    traced = load_and_trace_hf_model(model_name, quantize=quantize)
     if traced is None:
         return None
     return collect_inlined_ops(traced)
@@ -81,7 +81,7 @@ def main():
                         help="Path to reference_models.json config file")
     args = parser.parse_args()
 
-    reference_models = load_reference_models(args.config)
+    reference_models = load_model_config(args.config)
     per_model_ops = {}
     union_ops = set()
 
@@ -90,8 +90,9 @@
           file=sys.stderr)
 
     failed = []
-    for arch, model_name in reference_models.items():
-        ops = extract_ops_for_model(model_name)
+    for arch, spec in reference_models.items():
+        ops = extract_ops_for_model(spec["model_id"],
+                                    quantize=spec["quantized"])
         if ops is None:
             failed.append(arch)
             print(f"  {arch}: FAILED", file=sys.stderr)
@@ -109,7 +110,8 @@
         "pytorch_version": torch.__version__,
         "models": {
             arch: {
-                "model_id": reference_models[arch],
+                "model_id": reference_models[arch]["model_id"],
+                "quantized": reference_models[arch]["quantized"],
                 "ops": sorted(ops),
             }
             for arch, ops in sorted(per_model_ops.items())
@@ -125,7 +127,11 @@
 
     if args.per_model:
         for arch, ops in sorted(per_model_ops.items()):
-            print(f"\n=== {arch} ({reference_models[arch]}) ===")
+            spec = reference_models[arch]
+            label = spec["model_id"]
+            if spec["quantized"]:
+                label += " (quantized)"
+            print(f"\n=== {arch} ({label}) ===")
             for op in sorted(ops):
                 print(f"  {op}")
diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json
index 255762721..76aefef01 100644
--- a/dev-tools/extract_model_ops/reference_models.json
+++ b/dev-tools/extract_model_ops/reference_models.json
@@ -16,5 +16,10 @@
   "elastic-hugging-face-elser": "elastic/hugging-face-elser",
   "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized",
   "elastic-splade-v3": "elastic/splade-v3",
-  "elastic-test-elser-v2": "elastic/test-elser-v2"
+  "elastic-test-elser-v2": "elastic/test-elser-v2",
+
+  "_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.",
+  "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
+  "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
+  "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}
 }
diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py
index 7ad860b58..33042f261 100644
--- a/dev-tools/extract_model_ops/torchscript_utils.py
+++ b/dev-tools/extract_model_ops/torchscript_utils.py
@@ -11,13 +11,52 @@
 #
 """Shared utilities for extracting and inspecting TorchScript operations."""
 
+import json
 import os
 import sys
+from pathlib import Path
 
 import torch
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 
 
+def load_model_config(config_path: Path) -> dict[str, dict]:
+    """Load a model config JSON file and normalise entries.
+
+    Each entry is either a plain model-name string or a dict with
+    ``model_id`` (required) and optional ``quantized`` boolean.  All
+    entries are normalised to ``{"model_id": str, "quantized": bool}``.
+    Keys starting with ``_comment`` are silently skipped.
+
+    Raises ``ValueError`` for malformed entries so that config problems
+    are caught early with an actionable message.
+    """
+    with open(config_path) as f:
+        raw = json.load(f)
+
+    models: dict[str, dict] = {}
+    for key, value in raw.items():
+        if key.startswith("_comment"):
+            continue
+        if isinstance(value, str):
+            models[key] = {"model_id": value, "quantized": False}
+        elif isinstance(value, dict):
+            if "model_id" not in value:
+                raise ValueError(
+                    f"Config entry {key!r} is a dict but missing required "
+                    f"'model_id' key: {value!r}")
+            models[key] = {
+                "model_id": value["model_id"],
+                "quantized": value.get("quantized", False),
+            }
+        else:
+            raise ValueError(
+                f"Config entry {key!r} has unsupported type "
+                f"{type(value).__name__}: {value!r}. "
+                f"Expected a model name string or a dict with 'model_id'.")
+    return models
+
+
 def collect_graph_ops(graph) -> set[str]:
     """Collect all operation names from a TorchScript graph, including blocks."""
     ops = set()
@@ -35,9 +74,13 @@
     return collect_graph_ops(graph)
 
 
-def load_and_trace_hf_model(model_name: str):
+def load_and_trace_hf_model(model_name: str, quantize: bool = False):
     """Load a HuggingFace model, tokenize sample input, and trace to TorchScript.
 
+    When *quantize* is True the model is dynamically quantized (nn.Linear
+    layers converted to quantized::linear_dynamic) before tracing.  This
+    mirrors what Eland does when importing models for Elasticsearch.
+
     Returns the traced module, or None if the model could not be loaded or traced.
     """
     token = os.environ.get("HF_TOKEN")
@@ -53,6 +96,16 @@
         print(f"  LOAD ERROR: {exc}", file=sys.stderr)
         return None
 
+    if quantize:
+        try:
+            model = torch.quantization.quantize_dynamic(
+                model, {torch.nn.Linear}, dtype=torch.qint8)
+            print("  Applied dynamic quantization (nn.Linear -> qint8)",
+                  file=sys.stderr)
+        except Exception as exc:
+            print(f"  QUANTIZE ERROR: {exc}", file=sys.stderr)
+            return None
+
     inputs = tokenizer(
         "This is a sample input for graph extraction.",
         return_tensors="pt", padding="max_length",
diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py
index 5d31d44bf..828749dbc 100644
--- a/dev-tools/extract_model_ops/validate_allowlist.py
+++ b/dev-tools/extract_model_ops/validate_allowlist.py
@@ -29,7 +29,6 @@
 """
 
 import argparse
-import json
 import re
 import sys
 from pathlib import Path
@@ -40,6 +39,7 @@
     collect_graph_ops,
     collect_inlined_ops,
     load_and_trace_hf_model,
+    load_model_config,
 )
 
 SCRIPT_DIR = Path(__file__).resolve().parent
@@ -103,10 +103,12 @@ def check_ops(ops: set[str],
 
 def validate_model(model_name: str, allowed: set[str], forbidden: set[str],
-                   verbose: bool) -> bool:
+                   verbose: bool,
+                   quantize: bool = False) -> bool:
     """Validate one HuggingFace model.
     Returns True if all ops pass."""
-    print(f"  {model_name}...", file=sys.stderr)
-    traced = load_and_trace_hf_model(model_name)
+    label = f"{model_name} (quantized)" if quantize else model_name
+    print(f"  {label}...", file=sys.stderr)
+    traced = load_and_trace_hf_model(model_name, quantize=quantize)
     if traced is None:
         print(f"  FAILED (could not load/trace)", file=sys.stderr)
         return False
@@ -151,14 +153,15 @@
 
     results: dict[str, bool] = {}
 
-    with open(args.config) as f:
-        models = json.load(f)
+    models = load_model_config(args.config)
+
     print(f"Validating {len(models)} HuggingFace models from "
           f"{args.config.name}...", file=sys.stderr)
 
-    for arch, model_id in models.items():
+    for arch, spec in models.items():
         results[arch] = validate_model(
-            model_id, allowed, forbidden, args.verbose)
+            spec["model_id"], allowed, forbidden, args.verbose,
+            quantize=spec["quantized"])
 
     if args.pt_dir and args.pt_dir.is_dir():
         pt_files = sorted(args.pt_dir.glob("*.pt"))
@@ -178,7 +181,11 @@
             if key.startswith("pt:"):
                 print(f"  {key}: {status}", file=sys.stderr)
             else:
-                print(f"  {key} ({models[key]}): {status}", file=sys.stderr)
+                spec = models[key]
+                label = spec["model_id"]
+                if spec["quantized"]:
+                    label += " (quantized)"
+                print(f"  {key} ({label}): {status}", file=sys.stderr)
     print("=" * 60, file=sys.stderr)
 
     if all_pass:
diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json
index 5c23eb907..0c853cdc5 100644
--- a/dev-tools/extract_model_ops/validation_models.json
+++ b/dev-tools/extract_model_ops/validation_models.json
@@ -19,6 +19,10 @@
   "elastic-splade-v3": "elastic/splade-v3",
   "elastic-test-elser-v2": "elastic/test-elser-v2",
 
+  "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
+  "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
+  "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},
+
   "ner-dslim-bert-base": "dslim/bert-base-NER",
   "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english",