[ML] Harden pytorch_inference with TorchScript model graph validation #2936

Merged
edsavage merged 30 commits into elastic:main from edsavage:feature/harden_pytorch_inference
Mar 12, 2026

Conversation


@edsavage edsavage commented Mar 2, 2026

Summary

Implements security hardening for pytorch_inference by validating TorchScript model graphs before execution, addressing elastic/ml-team#1770.

Model Graph Validation (C++)

  • CModelGraphValidator: Validates TorchScript model graphs by inlining all method calls (torch::jit::Inline) and recursively inspecting every node (including sub-blocks inside prim::If/prim::Loop).
  • CSupportedOperations: Defines a dual-list security model:
    • FORBIDDEN_OPERATIONS (4 ops): aten::execute_with_args, aten::from_file, prim::CallFunction, prim::CallMethod — rejected immediately with a clear error.
    • ALLOWED_OPERATIONS (82 ops): Exhaustive allowlist of safe tensor/control-flow ops derived from tracing reference models against PyTorch 2.7.1.
  • Maximum node count (MAX_NODE_COUNT = 1,000,000): Guards against resource exhaustion from excessively large graphs.
  • Debug logging: Observed ops are logged at DEBUG level during validation.
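The dual-list check and node budget described above can be sketched as follows. This is a minimal Python illustration of the validation order (the real implementation is C++ in CModelGraphValidator/CSupportedOperations); the op sets here are abridged and the function name is made up:

```python
# Sketch of the dual-list security model: forbidden ops are rejected first
# with a targeted error, then anything outside the allowlist is rejected.
FORBIDDEN_OPERATIONS = {
    "aten::execute_with_args", "aten::from_file",
    "prim::CallFunction", "prim::CallMethod",
}
# Abridged stand-in for the full 82-op allowlist.
ALLOWED_OPERATIONS = {"aten::add", "aten::matmul", "prim::Constant", "prim::If"}

MAX_NODE_COUNT = 1_000_000


def validate_ops(observed_ops):
    """Two-pass check: node budget, then forbidden ops, then the allowlist.

    observed_ops is one op kind per graph node, collected after inlining.
    """
    if len(observed_ops) > MAX_NODE_COUNT:
        return False, "graph exceeds maximum node count"
    forbidden = [op for op in observed_ops if op in FORBIDDEN_OPERATIONS]
    if forbidden:
        return False, f"forbidden operations: {forbidden}"
    unrecognised = [op for op in observed_ops if op not in ALLOWED_OPERATIONS]
    if unrecognised:
        return False, f"unrecognised operations: {unrecognised}"
    return True, "ok"
```

Checking the forbidden list first means a deliberately hostile model gets a specific error message rather than a generic "unrecognised op" one.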

Operation Allowlist Tooling (Python)

  • dev-tools/extract_model_ops/: Self-contained tooling directory with:
    • extract_model_ops.py — generates the C++ allowlist from reference HuggingFace models
    • validate_allowlist.py — integration test verifying no false positives against 24 HuggingFace models + 3 Elasticsearch integration test models
    • reference_models.json / validation_models.json — model configurations
    • es_it_models/ — extracted .pt models from Elasticsearch's PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT
    • requirements.txt — pinned to torch==2.7.1 matching the libtorch build version
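Since validate_allowlist.py parses the op lists straight out of the C++ source, a sketch of that parsing step may be useful. The exact C++ declaration format and the helper name below are assumptions for illustration:

```python
import re

# Hypothetical excerpt of CSupportedOperations.cc; the real file's
# formatting may differ, but the quoted-op extraction idea is the same.
CC_SNIPPET = '''
const std::set<std::string> ALLOWED_OPERATIONS{
    "aten::add", "aten::matmul", "prim::Constant",
};
const std::set<std::string> FORBIDDEN_OPERATIONS{
    "aten::from_file", "prim::CallMethod",
};
'''


def parse_op_list(source: str, name: str) -> set:
    """Extract the quoted op names from a brace-initialised C++ set."""
    match = re.search(name + r"\{(.*?)\};", source, re.S)
    if match is None:
        raise ValueError(f"{name} not found in source")
    return set(re.findall(r'"([^"]+)"', match.group(1)))
```

Parsing the lists from the .cc file rather than duplicating them in Python keeps the integration test honest: there is a single source of truth for what the validator accepts.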

CI Integration

  • cmake/run-validation.cmake: Portable CMake script that locates Python 3 (searching python3, python3.12, ..., python), manages a virtual environment, handles DYLD_LIBRARY_PATH/LD_LIBRARY_PATH for libtorch conflicts, and runs the validation. Supports OPTIONAL=TRUE for graceful skip when Python or network is unavailable.
  • Wired into test_all_parallel and precommit with OPTIONAL=TRUE — runs automatically when Python is available, skips with a warning otherwise (e.g. in Docker containers without network).
  • Standalone target validate_pytorch_inference_models available for explicit verification (hard failure mode).

C++ Tests

  • Unit tests (CModelGraphValidatorTest.cc): Tests for allowed/forbidden/unrecognised ops, graph inlining, node count enforcement, and integration tests using torch::jit::Module::define().
  • Malicious model tests: 6 generated .pt fixtures testing detection of aten::from_file, hidden ops in submodules, conditional branches, and mixed scenarios.

Test plan

  • All existing C++ unit tests pass (cmake --build ... -t test)
  • New CModelGraphValidator tests pass (50 pytorch_inference test cases)
  • Malicious model fixtures correctly rejected
  • Python validation passes for all 27 models (24 HuggingFace + 3 ES .pt)
  • OPTIONAL=TRUE gracefully skips when Python unavailable
  • CI passes on all platforms (Linux x86_64, Linux aarch64, macOS aarch64, Windows x86_64)
  • CI passes in both RelWithDebInfo and Debug configurations

Made with Cursor

edsavage added 17 commits March 2, 2026 10:54
Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

New files:
- CSupportedOperations: allowlist of 71 ops from 10 reference architectures
- CModelGraphValidator: recursive graph walker and validation logic
- CModelGraphValidatorTest: 10 unit tests covering pass/fail/edge cases
- extract_model_ops.py: developer tool to regenerate the allowlist

Relates to elastic/ml-team#1770

Made-with: Cursor
…onfig

- Move script to dev-tools/extract_model_ops/ subdirectory
- Extract REFERENCE_MODELS dict to reference_models.json config file
- Add requirements.txt for virtual environment setup
- Add README.md with setup, usage, and configuration instructions
- Update CSupportedOperations path references

Made-with: Cursor
…ript

- Add all 10 elastic/* models from HuggingFace to reference_models.json
- Make extract_model_ops.py resilient to individual model load/trace
  failures (continues to next model instead of crashing)
- Add sentencepiece and protobuf to requirements.txt
- Add .gitignore for .venv directory
- Update CSupportedOperations.cc comment with expanded model list
- Op union remains 71 ops (Elastic models use same base architectures)

Made-with: Cursor
Remove bart and elastic/multilingual-e5-small which cannot be
traced or scripted with the current transformers/torch versions.

Made-with: Cursor
Explain why both a short forbidden list and a broad allowed list are
maintained: targeted error messages, safety net against accidental
allowlist expansion, and defence-in-depth.

Made-with: Cursor
Re-ran extraction with torch 2.7.1 (matching the libtorch version
linked by ml-cpp) -- op set is identical to the 2.10.0 run.  Pin
torch version in requirements.txt and fix the comment.

Made-with: Cursor
Aids debugging when a legitimate model is unexpectedly rejected after
a PyTorch upgrade, and provides an audit trail of what was loaded.

Made-with: Cursor
…Method

Use torch::jit::Inline() to flatten method calls before collecting
operations.  This ensures ops hidden behind prim::CallMethod are
surfaced for validation.  After inlining, prim::CallMethod and
prim::CallFunction should not appear; add them to the forbidden
list so any unresolvable call is explicitly rejected.

Made-with: Cursor
Reject models whose inlined computation graph exceeds 1M nodes.
Typical transformer models have O(10k) nodes; the generous limit
prevents pathologically crafted models from causing excessive memory
or CPU usage during graph traversal.

Made-with: Cursor
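The early bail-out during traversal can be sketched like this. The real code walks torch::jit graph nodes in C++; `Node` below is a toy stand-in, and the function name is invented:

```python
from dataclasses import dataclass, field

MAX_NODE_COUNT = 1_000_000


@dataclass
class Node:
    kind: str                               # e.g. "aten::add", "prim::If"
    blocks: list = field(default_factory=list)  # sub-blocks (if-branches, loop bodies)


def collect_ops(nodes, ops, count=0, limit=MAX_NODE_COUNT):
    """Recursively collect op kinds, bailing out once the node budget is spent.

    Raising immediately on the limit means a pathologically large graph
    costs at most `limit` node visits, not a full traversal.
    """
    for node in nodes:
        count += 1
        if count > limit:
            raise RuntimeError("model graph exceeds MAX_NODE_COUNT")
        ops.add(node.kind)
        for block in node.blocks:
            count = collect_ops(block, ops, count, limit)
    return count
```

Threading the running count through the recursion (rather than counting per block) is what makes the limit apply to the whole inlined graph, sub-blocks included.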
Construct scriptable modules with define() and validate them through
the full CModelGraphValidator pipeline.  Covers: a valid module with
allowed ops, a module with unrecognised ops, node count tracking, and
a parent/child module pair that exercises graph inlining.

Made-with: Cursor
Adds validate_allowlist.py alongside extract_model_ops.py in
dev-tools/extract_model_ops/.  The script parses ALLOWED_OPERATIONS
and FORBIDDEN_OPERATIONS directly from CSupportedOperations.cc, then
traces every model in validation_models.json and checks for false
positives.

validation_models.json is a superset of reference_models.json that
also includes task-specific models (NER, sentiment analysis) matching
the bin/pytorch_inference/examples/ test data.

A wrapper script (run_validation.sh) automatically creates the Python
venv and installs dependencies on first run.  A CMake target is
registered for convenient invocation:
  cmake --build <build-dir> -t validate_pytorch_inference_models

Made-with: Cursor
Extend the allowlist validation to cover models directly referenced in
the Elasticsearch repo and its eland import tool: the packaged
multilingual-e5-small, the cross-encoder reranker from the docs, the
sentence-transformers embedding model from eland tests, and the DPR
question encoder. All 24 models pass validation with no false positives.

Made-with: Cursor
Extract the base64-encoded TorchScript models from PyTorchModelIT,
TextExpansionQueryIT, and TextEmbeddingQueryIT in the Elasticsearch
repo and validate them against our operation allowlist. These toy
models use basic ops (aten::ones, aten::rand, aten::hash, prim::Loop,
etc.) that weren't in the transformer-derived allowlist, so add them.
All are safe tensor/control-flow operations with no I/O capability.

The validation script now accepts --pt-dir to validate pre-saved .pt
files alongside HuggingFace models. The CMake target passes the new
es_it_models directory automatically.

Made-with: Cursor
Create six malicious .pt model fixtures that exercise specific attack
vectors the CModelGraphValidator must detect:

- malicious_file_reader: uses aten::from_file to read arbitrary files
- malicious_mixed_file_reader: hides aten::from_file among allowed ops
- malicious_hidden_in_submodule: buries unrecognised ops 3 levels deep
- malicious_conditional: hides unrecognised ops inside if-branches
- malicious_many_unrecognised: uses sin/cos/tan/exp (unknown arch)
- malicious_file_reader_in_submodule: forbidden op hidden in child module

Each test loads the real .pt file via torch::jit::load and verifies the
validator correctly identifies and rejects it. Includes the Python
generator script for reproducibility.

Made-with: Cursor
Replace the bash wrapper script with cmake/run-validation.cmake that
works across all CI platforms (Linux, macOS, Windows). The CMake script
searches for python3, python3.12, python3.11, python3.10, python3.9,
and python — handling Linux build machines where Python is only
available as python3.12 (via make altinstall) and Windows where the
canonical name is python. It also prepends the venv's torch/lib
directory to the dynamic library search path to avoid conflicts with
any system-installed libtorch.

Made-with: Cursor
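The interpreter search can be sketched as a CMake fragment. This is an abridged illustration, not the actual contents of cmake/run-validation.cmake; the candidate list comes from the description above, and the variable names are assumptions:

```cmake
# find_program accepts multiple NAMES and returns the first hit, which
# covers Linux boxes with only python3.12 (make altinstall) and Windows
# where the canonical name is plain "python".
find_program(PYTHON_EXE
    NAMES python3 python3.12 python3.11 python3.10 python3.9 python)

if(NOT PYTHON_EXE)
    if(OPTIONAL)
        message(WARNING "Python 3 not found -- skipping model op validation")
        return()
    endif()
    message(FATAL_ERROR "Python 3 is required for model op validation")
endif()
```

The OPTIONAL branch is what lets the same script serve both as a soft CI step and as the hard-failing standalone target.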
Add the Python allowlist validation as a step in test_all_parallel
(used by CI) and precommit (used by developers). Both use OPTIONAL=TRUE
so the validation is gracefully skipped with a warning when Python 3 is
not available or pip cannot install dependencies (e.g. in Docker
containers without network access). The standalone
validate_pytorch_inference_models target remains a hard failure for
explicit use.

Made-with: Cursor

prodsecmachine commented Mar 2, 2026

Snyk checks have passed. No issues have been found so far.

Scanner              Critical  High  Medium  Low  Total
Open Source Security        0     0       0    0  0 issues
Licenses                    0     0       0    0  0 issues


edsavage added 6 commits March 3, 2026 10:37
Replace relative "../Foo.h" includes with <Foo.h> by adding the parent
source directory to the test target's include path. Also remove
unnecessary backslash escapes in extract_model_ops README.

Made-with: Cursor
Deduplicate collect_graph_ops, graph inlining, and HuggingFace model
loading/tracing logic shared between extract_model_ops.py and
validate_allowlist.py into a common module.

Made-with: Cursor
@edsavage edsavage marked this pull request as ready for review March 9, 2026 20:40
@edsavage edsavage requested a review from valeriy42 March 9, 2026 20:40
- Check MAX_NODE_COUNT during graph traversal to prevent resource
  exhaustion on pathologically large models (bail out immediately
  in collectBlockOps and collectModuleOps).
- Two-pass validation: check forbidden ops first, skip unrecognised
  op scan when forbidden ops are found.
- Add aten::as_strided to FORBIDDEN_OPERATIONS (key enabler of
  heap-leak and ROP chain attacks).
- Change LOG_FATAL to HANDLE_FATAL in the c10::Error catch block
  so an exception during validation terminates the process.
- Fix CHANGELOG asciidoc link syntax.
- Move generate_malicious_models.py to dev-tools/.
- Remove redundant Python test scripts now that C++ integration
  tests cover the same attack models.
- Remove PR cross-references from comments per reviewer request.

Made-with: Cursor
Add a C++ test (testAllowlistCoversReferenceModels) that loads a
golden JSON file containing per-architecture TorchScript op sets
extracted from 18 reference HuggingFace models and verifies every
op is in ALLOWED_OPERATIONS and none are in FORBIDDEN_OPERATIONS.

This catches allowlist regressions in CI without requiring Python
or network access.  When PyTorch is upgraded, regenerate the golden
file with:

  python3 extract_model_ops.py --golden \
    bin/pytorch_inference/unittest/testfiles/reference_model_ops.json

The --golden flag is a new addition to extract_model_ops.py that
outputs per-model op sets as structured JSON.

Made-with: Cursor
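The golden-file check can be mirrored in a few lines of Python. The real test (testAllowlistCoversReferenceModels) is C++; the JSON shape and op sets below are illustrative assumptions:

```python
import json

# Hypothetical golden file: per-model op sets extracted offline.
GOLDEN_JSON = '''
{
  "bert-base-uncased": ["aten::add", "aten::matmul", "prim::Constant"],
  "distilbert": ["aten::add", "prim::If"]
}
'''

# Abridged stand-ins for the real lists in CSupportedOperations.
ALLOWED_OPERATIONS = {"aten::add", "aten::matmul", "prim::Constant", "prim::If"}
FORBIDDEN_OPERATIONS = {"aten::from_file", "prim::CallMethod"}


def check_golden(golden_text):
    """Every op observed in a reference model must be allowed and not forbidden."""
    failures = []
    for model, ops in json.loads(golden_text).items():
        for op in ops:
            if op in FORBIDDEN_OPERATIONS or op not in ALLOWED_OPERATIONS:
                failures.append((model, op))
    return failures
```

Because the golden file is checked into the repo, the regression test needs neither Python nor network access at CI time; only regenerating it after a PyTorch upgrade does.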
@edsavage edsavage requested a review from valeriy42 March 12, 2026 03:26

@valeriy42 valeriy42 left a comment


LGTM. Great work!

I left only one comment regarding automatic running against the PyTorch edge branch.

@valeriy42 valeriy42 added v9.2.7 v9.3.2 v8.19.13 auto-backport Automatically merge backport PRs when CI passes labels Mar 12, 2026
@valeriy42
Contributor

@edsavage , although it's an enhancement, for obvious reasons, I added backporting to all supported versions.

@edsavage
Contributor Author

I left only one comment regarding automatic running against the PyTorch edge branch.

Yes, that's a really good idea, I'll look into doing that in a separate PR.

@edsavage edsavage merged commit 38f6653 into elastic:main Mar 12, 2026
19 checks passed
@github-actions

💔 All backports failed

Status Branch Result
9.2 Backport failed because of merge conflicts

You might need to backport the following PRs to 9.2:
- test: create bot PR for self-approve test (#2976)
- test: verify backport with auto-backport label (#2951)
- test: verify backport without auto-backport label (#2948)
- [ML] Add resource monitoring in CBucketGatherer::addEventData (#2848)
- [ML] Upgrade to PyTorch 2.7.1 (#2863)
- [ML] Report the "actual" memory usage of the autodetect process (#2846)
- [ML] Add a script to run each unit test separately (#2859)
9.3 Backport failed because of merge conflicts

You might need to backport the following PRs to 9.3:
- test: verify backport with auto-backport label (#2951)
- test: verify backport without auto-backport label (#2948)
8.19 Backport failed because of merge conflicts

Manual backport

To create the backport manually run:

backport --pr 2936

Questions ?

Please refer to the Backport tool documentation and see the Github Action logs for details

edsavage added a commit to edsavage/ml-cpp that referenced this pull request Mar 12, 2026
…elastic#2936)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

(cherry picked from commit 38f6653)
edsavage added a commit to edsavage/ml-cpp that referenced this pull request Mar 12, 2026
…elastic#2936)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

(cherry picked from commit 38f6653)
edsavage added a commit to edsavage/ml-cpp that referenced this pull request Mar 12, 2026
…elastic#2936)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

(cherry picked from commit 38f6653)
edsavage added a commit that referenced this pull request Mar 12, 2026
…idation (#2936) (#2988)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

Backports #2936
edsavage added a commit that referenced this pull request Mar 12, 2026
…#2936) (#2987)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

Backports #2936
edsavage added a commit that referenced this pull request Mar 13, 2026
…#2936) (#2986)

Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

Backports #2936
valeriy42 added a commit to valeriy42/ml-cpp that referenced this pull request Mar 13, 2026
valeriy42 added a commit that referenced this pull request Mar 13, 2026
* Revert "[ML] Add quantized model ops to pytorch_inference allowlist (#2991)"

This reverts commit 92432d6.

* Revert "[ML] Harden pytorch_inference with TorchScript model graph validation (#2936)"

This reverts commit 38f6653.

* fix run_qa_tests buildkite step
edsavage added a commit to edsavage/ml-cpp that referenced this pull request Mar 15, 2026