Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions bin/pytorch_inference/CSupportedOperations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
// elastic/test-elser-v2.
// Additional ops from Elasticsearch integration test models
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
// Quantized operations from dynamically quantized variants of the above
// models (torch.quantization.quantize_dynamic on nn.Linear layers).
const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = {
// aten operations (core tensor computations)
"aten::Int"sv,
Expand Down Expand Up @@ -79,6 +81,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"aten::mean"sv,
"aten::min"sv,
"aten::mul"sv,
"aten::mul_"sv,
"aten::ne"sv,
"aten::neg"sv,
"aten::new_ones"sv,
Expand Down Expand Up @@ -124,6 +127,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
"prim::dtype"sv,
"prim::max"sv,
"prim::min"sv,
// quantized operations (dynamically quantized models, e.g. ELSER v2)
"quantized::linear_dynamic"sv,
};
}
}
40 changes: 40 additions & 0 deletions bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
},
"elastic-eis-elser-v2": {
"model_id": "elastic/eis-elser-v2",
"quantized": false,
"ops": [
"aten::Int",
"aten::ScalarImplicit",
Expand Down Expand Up @@ -303,6 +304,7 @@
},
"elastic-elser-v2": {
"model_id": "elastic/elser-v2",
"quantized": false,
"ops": [
"aten::Int",
"aten::ScalarImplicit",
Expand Down Expand Up @@ -337,6 +339,44 @@
"prim::NumToTensor"
]
},
"elastic-elser-v2-quantized": {
"model_id": "elastic/elser-v2",
"quantized": true,
"ops": [
"aten::Int",
"aten::ScalarImplicit",
"aten::__and__",
"aten::add",
"aten::arange",
"aten::contiguous",
"aten::dropout",
"aten::embedding",
"aten::expand",
"aten::gather",
"aten::ge",
"aten::gelu",
"aten::index",
"aten::layer_norm",
"aten::mul_",
"aten::new_ones",
"aten::reshape",
"aten::scaled_dot_product_attention",
"aten::select",
"aten::size",
"aten::slice",
"aten::tanh",
"aten::to",
"aten::transpose",
"aten::unsqueeze",
"aten::view",
"prim::Constant",
"prim::DictConstruct",
"prim::GetAttr",
"prim::ListConstruct",
"prim::NumToTensor",
"quantized::linear_dynamic"
]
},
"elastic-hugging-face-elser": {
"model_id": "elastic/hugging-face-elser",
"ops": [
Expand Down
36 changes: 21 additions & 15 deletions dev-tools/extract_model_ops/extract_model_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@

import torch

from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model
from torchscript_utils import (
collect_inlined_ops,
load_and_trace_hf_model,
load_model_config,
)

SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json"


def load_reference_models(config_path: Path) -> dict[str, str]:
"""Load the architecture-to-model mapping from a JSON config file."""
with open(config_path) as f:
return json.load(f)


def extract_ops_for_model(model_name: str,
                          quantize: bool = False) -> set[str] | None:
    """Return the TorchScript op set of a traced HuggingFace model.

    Traces *model_name* (optionally after dynamic quantization, when
    *quantize* is True) and collects the inlined graph operations.

    Returns None if the model could not be loaded or traced.
    """
    label = f"{model_name} (quantized)" if quantize else model_name
    print(f"  Loading {label}...", file=sys.stderr)
    module = load_and_trace_hf_model(model_name, quantize=quantize)
    return None if module is None else collect_inlined_ops(module)
Expand Down Expand Up @@ -81,7 +81,7 @@ def main():
help="Path to reference_models.json config file")
args = parser.parse_args()

reference_models = load_reference_models(args.config)
reference_models = load_model_config(args.config)

per_model_ops = {}
union_ops = set()
Expand All @@ -90,8 +90,9 @@ def main():
file=sys.stderr)

failed = []
for arch, model_name in reference_models.items():
ops = extract_ops_for_model(model_name)
for arch, spec in reference_models.items():
ops = extract_ops_for_model(spec["model_id"],
quantize=spec["quantized"])
if ops is None:
failed.append(arch)
print(f" {arch}: FAILED", file=sys.stderr)
Expand All @@ -109,7 +110,8 @@ def main():
"pytorch_version": torch.__version__,
"models": {
arch: {
"model_id": reference_models[arch],
"model_id": reference_models[arch]["model_id"],
"quantized": reference_models[arch]["quantized"],
"ops": sorted(ops),
}
for arch, ops in sorted(per_model_ops.items())
Expand All @@ -125,7 +127,11 @@ def main():

if args.per_model:
for arch, ops in sorted(per_model_ops.items()):
print(f"\n=== {arch} ({reference_models[arch]}) ===")
spec = reference_models[arch]
label = spec["model_id"]
if spec["quantized"]:
label += " (quantized)"
print(f"\n=== {arch} ({label}) ===")
for op in sorted(ops):
print(f" {op}")

Expand Down
7 changes: 6 additions & 1 deletion dev-tools/extract_model_ops/reference_models.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,10 @@
"elastic-hugging-face-elser": "elastic/hugging-face-elser",
"elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized",
"elastic-splade-v3": "elastic/splade-v3",
"elastic-test-elser-v2": "elastic/test-elser-v2"
"elastic-test-elser-v2": "elastic/test-elser-v2",

"_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.",
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}
}
55 changes: 54 additions & 1 deletion dev-tools/extract_model_ops/torchscript_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,52 @@
#
"""Shared utilities for extracting and inspecting TorchScript operations."""

import json
import os
import sys
from pathlib import Path

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer


def load_model_config(config_path: Path) -> dict[str, dict]:
    """Load a model config JSON file and normalise entries.

    Each entry is either a plain model-name string or a dict with
    ``model_id`` (required) and optional ``quantized`` boolean. All
    entries are normalised to ``{"model_id": str, "quantized": bool}``.
    Keys starting with ``_comment`` are silently skipped.

    Raises ``ValueError`` for malformed entries so that config problems
    are caught early with an actionable message.
    """
    with open(config_path) as f:
        raw = json.load(f)

    models: dict[str, dict] = {}
    for key, value in raw.items():
        # Convention for annotating the JSON file in-line.
        if key.startswith("_comment"):
            continue
        if isinstance(value, str):
            models[key] = {"model_id": value, "quantized": False}
        elif isinstance(value, dict):
            if "model_id" not in value:
                raise ValueError(
                    f"Config entry {key!r} is a dict but missing required "
                    f"'model_id' key: {value!r}")
            model_id = value["model_id"]
            # Validate field types here, not downstream: a non-string
            # model_id or a string "true" for quantized would otherwise be
            # accepted silently and fail far from the config file.
            if not isinstance(model_id, str):
                raise ValueError(
                    f"Config entry {key!r} has non-string 'model_id': "
                    f"{model_id!r}")
            quantized = value.get("quantized", False)
            if not isinstance(quantized, bool):
                raise ValueError(
                    f"Config entry {key!r} has non-boolean 'quantized': "
                    f"{quantized!r}")
            models[key] = {"model_id": model_id, "quantized": quantized}
        else:
            raise ValueError(
                f"Config entry {key!r} has unsupported type "
                f"{type(value).__name__}: {value!r}. "
                f"Expected a model name string or a dict with 'model_id'.")
    return models


def collect_graph_ops(graph) -> set[str]:
"""Collect all operation names from a TorchScript graph, including blocks."""
ops = set()
Expand All @@ -35,9 +74,13 @@ def collect_inlined_ops(module) -> set[str]:
return collect_graph_ops(graph)


def load_and_trace_hf_model(model_name: str):
def load_and_trace_hf_model(model_name: str, quantize: bool = False):
"""Load a HuggingFace model, tokenize sample input, and trace to TorchScript.

When *quantize* is True the model is dynamically quantized (nn.Linear
layers converted to quantized::linear_dynamic) before tracing. This
mirrors what Eland does when importing models for Elasticsearch.

Returns the traced module, or None if the model could not be loaded or traced.
"""
token = os.environ.get("HF_TOKEN")
Expand All @@ -53,6 +96,16 @@ def load_and_trace_hf_model(model_name: str):
print(f" LOAD ERROR: {exc}", file=sys.stderr)
return None

if quantize:
try:
model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8)
print(" Applied dynamic quantization (nn.Linear -> qint8)",
file=sys.stderr)
except Exception as exc:
print(f" QUANTIZE ERROR: {exc}", file=sys.stderr)
return None

inputs = tokenizer(
"This is a sample input for graph extraction.",
return_tensors="pt", padding="max_length",
Expand Down
25 changes: 16 additions & 9 deletions dev-tools/extract_model_ops/validate_allowlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"""

import argparse
import json
import re
import sys
from pathlib import Path
Expand All @@ -40,6 +39,7 @@
collect_graph_ops,
collect_inlined_ops,
load_and_trace_hf_model,
load_model_config,
)

SCRIPT_DIR = Path(__file__).resolve().parent
Expand Down Expand Up @@ -103,10 +103,12 @@ def check_ops(ops: set[str],
def validate_model(model_name: str,
allowed: set[str],
forbidden: set[str],
verbose: bool) -> bool:
verbose: bool,
quantize: bool = False) -> bool:
"""Validate one HuggingFace model. Returns True if all ops pass."""
print(f" {model_name}...", file=sys.stderr)
traced = load_and_trace_hf_model(model_name)
label = f"{model_name} (quantized)" if quantize else model_name
print(f" {label}...", file=sys.stderr)
traced = load_and_trace_hf_model(model_name, quantize=quantize)
if traced is None:
print(f" FAILED (could not load/trace)", file=sys.stderr)
return False
Expand Down Expand Up @@ -151,14 +153,15 @@ def main():

results: dict[str, bool] = {}

with open(args.config) as f:
models = json.load(f)
models = load_model_config(args.config)

print(f"Validating {len(models)} HuggingFace models from "
f"{args.config.name}...", file=sys.stderr)

for arch, model_id in models.items():
for arch, spec in models.items():
results[arch] = validate_model(
model_id, allowed, forbidden, args.verbose)
spec["model_id"], allowed, forbidden, args.verbose,
quantize=spec["quantized"])

if args.pt_dir and args.pt_dir.is_dir():
pt_files = sorted(args.pt_dir.glob("*.pt"))
Expand All @@ -178,7 +181,11 @@ def main():
if key.startswith("pt:"):
print(f" {key}: {status}", file=sys.stderr)
else:
print(f" {key} ({models[key]}): {status}", file=sys.stderr)
spec = models[key]
label = spec["model_id"]
if spec["quantized"]:
label += " (quantized)"
print(f" {key} ({label}): {status}", file=sys.stderr)

print("=" * 60, file=sys.stderr)
if all_pass:
Expand Down
4 changes: 4 additions & 0 deletions dev-tools/extract_model_ops/validation_models.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
"elastic-splade-v3": "elastic/splade-v3",
"elastic-test-elser-v2": "elastic/test-elser-v2",

"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},

"ner-dslim-bert-base": "dslim/bert-base-NER",
"sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english",

Expand Down