Commit d3df09c

[ML] Harden pytorch_inference with TorchScript model graph validation (#2999)
Re-applies #2936 and #2991 which were reverted in #2995.

- Adds a static TorchScript graph validation layer (CModelGraphValidator, CSupportedOperations) that rejects models containing operations not observed in supported transformer architectures, reducing the attack surface by ensuring only known-safe operation sets are permitted.
- Includes aten::mul_ and quantized::linear_dynamic in the allowed operations for dynamically quantized models (e.g. ELSER v2 imported via Eland).
- Adds Python extraction tooling (dev-tools/extract_model_ops/) to trace reference HuggingFace models and collect their op sets, with support for quantized variants.
- Adds reference_model_ops.json golden file and C++ drift test to detect allowlist staleness on PyTorch upgrades.
- Adds adversarial "evil model" integration tests to verify rejection of forbidden operations.
- Adds CHANGELOG entry.

1 parent 84d2d91 commit d3df09c

44 files changed: +3449 -181 lines changed

.buildkite/pipeline.json.py

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,8 @@ def main():
     if config.build_x86_64:
         pipeline_steps.append(pipeline_steps.generate_step("Upload ES tests x86_64 runner pipeline",
                                                            ".buildkite/pipelines/run_es_tests_x86_64.yml.sh"))
+        pipeline_steps.append(pipeline_steps.generate_step("Upload ES inference tests x86_64 runner pipeline",
+                                                           ".buildkite/pipelines/run_es_inference_tests_x86_64.yml.sh"))
     # We only use linux x86_64 builds for QA tests.
     if config.run_qa_tests:
         pipeline_steps.append(pipeline_steps.generate_step("Upload QA tests runner pipeline",

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+# or more contributor license agreements. Licensed under the Elastic License
+# 2.0 and the following additional limitation. Functionality enabled by the
+# files subject to the Elastic License 2.0 may only be used in production when
+# invoked by an Elasticsearch process with a license key installed that permits
+# use of machine learning features. You may not use this file except in
+# compliance with the Elastic License 2.0 and the foregoing additional
+# limitation.
+
+cat <<EOL
+steps:
+  - label: "Java :java: Inference Integration Tests for x86_64 :hammer:"
+    key: "java_inference_tests_x86_64"
+    command:
+      - 'sudo rpm --import https://yum.corretto.aws/corretto.key'
+      - 'sudo curl -L -o /etc/yum.repos.d/corretto.repo https://yum.corretto.aws/corretto.repo'
+      - 'sudo dnf install -y java-21-amazon-corretto-devel'
+      - 'buildkite-agent artifact download "build/*" . --step build_test_linux-x86_64-RelWithDebInfo'
+      - '.buildkite/scripts/steps/run_es_inference_tests.sh || (cd ../elasticsearch && find x-pack -name logs | xargs tar cvzf logs.tgz && buildkite-agent artifact upload logs.tgz && false)'
+    depends_on: "build_test_linux-x86_64-RelWithDebInfo"
+    agents:
+      provider: aws
+      instanceType: m6i.2xlarge
+      imagePrefix: core-amazonlinux-2023
+      diskSizeGb: 100
+      diskName: '/dev/xvda'
+    env:
+      IVY_REPO: "../ivy"
+      GRADLE_JVM_OPTS: "-Dorg.gradle.jvmargs=-Xmx16g"
+    notify:
+      - github_commit_status:
+          context: "Java Inference Integration Tests for x86_64"
+EOL

.buildkite/pipelines/run_pytorch_tests.yml.sh

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,8 @@
 # compliance with the Elastic License 2.0 and the foregoing additional
 # limitation.

+SAFE_MESSAGE=$(printf '%s' "${BUILDKITE_MESSAGE}" | head -1 | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g')
+
 cat <<EOL
 steps:
   - label: "Trigger Appex PyTorch Tests :test_tube:"
@@ -22,7 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: "${BUILDKITE_MESSAGE}"
+      message: "${SAFE_MESSAGE}"
       env:
         QAF_TESTS_TO_RUN: "pytorch_tests"
 EOL

.buildkite/pipelines/run_qa_tests.yml.sh

Lines changed: 3 additions & 7 deletions
@@ -8,6 +8,8 @@
 # compliance with the Elastic License 2.0 and the foregoing additional
 # limitation.

+SAFE_MESSAGE=$(printf '%s' "${BUILDKITE_MESSAGE}" | head -1 | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g')
+
 cat <<EOL
 steps:
   - label: "Trigger Appex QA Tests :test_tube:"
@@ -22,13 +24,7 @@ steps:
   - trigger: appex-qa-stateful-custom-ml-cpp-build-testing
     async: false
     build:
-      message: |
-EOL
-
-# Output the message with proper indentation for YAML literal block scalar
-printf '%s\n' "${BUILDKITE_MESSAGE}" | sed 's/^/ /'
-
-cat <<EOL
+      message: "${SAFE_MESSAGE}"
       env:
         QAF_TESTS_TO_RUN: "${QAF_TESTS_TO_RUN:-ml_cpp_pr}"
 EOL
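
Both pipeline scripts above build SAFE_MESSAGE the same way: keep only the first line of the commit message, escape backslashes, then escape double quotes, so the value can sit inside a double-quoted YAML scalar. The following Python sketch mirrors that transformation for illustration only; the function name `safe_message` is hypothetical and is not part of the commit.

```python
def safe_message(message: str) -> str:
    # Rough Python equivalent of the shell pipeline:
    #   printf '%s' "$MSG" | head -1 | sed -e 's/\\/\\\\/g' -e 's/"/\\"/g'
    # Keep only the first line (head -1).
    first_line = message.split("\n", 1)[0]
    # Escape backslashes first so the backslashes added for quotes
    # in the next step are not doubled again.
    escaped = first_line.replace("\\", "\\\\")
    # Escape double quotes so they cannot terminate the YAML scalar.
    return escaped.replace('"', '\\"')


if __name__ == "__main__":
    print(safe_message('Fix "quoted" title\n\nLonger body text'))
```

The order of the two substitutions matters: escaping quotes before backslashes would double the backslash that was just added in front of each quote.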
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+# or more contributor license agreements. Licensed under the Elastic License
+# 2.0 and the following additional limitation. Functionality enabled by the
+# files subject to the Elastic License 2.0 may only be used in production when
+# invoked by an Elasticsearch process with a license key installed that permits
+# use of machine learning features. You may not use this file except in
+# compliance with the Elastic License 2.0 and the foregoing additional
+# limitation.
+
+set -euo pipefail
+
+echo "pwd = $(pwd)"
+
+export HARDWARE_ARCH=$(uname -m | sed 's/arm64/aarch64/')
+
+VERSION=$(cat ${REPO_ROOT}/gradle.properties | grep '^elasticsearchVersion' | awk -F= '{ print $2 }' | xargs echo)
+if [ "${BUILD_SNAPSHOT:=true}" = "true" ] ; then
+    VERSION=${VERSION}-SNAPSHOT
+fi
+export VERSION
+
+export PR_AUTHOR=$(expr "$BUILDKITE_BRANCH" : '\(.*\):.*')
+export PR_SOURCE_BRANCH=$(expr "$BUILDKITE_BRANCH" : '.*:\(.*\)')
+export PR_TARGET_BRANCH=${BUILDKITE_PULL_REQUEST_BASE_BRANCH}
+
+mkdir -p "${IVY_REPO}/maven/org/elasticsearch/ml/ml-cpp/$VERSION"
+cp "build/distributions/ml-cpp-$VERSION-linux-$HARDWARE_ARCH.zip" "${IVY_REPO}/maven/org/elasticsearch/ml/ml-cpp/$VERSION/ml-cpp-$VERSION.zip"
+cp "build/distributions/ml-cpp-$VERSION-linux-$HARDWARE_ARCH.zip" "${IVY_REPO}/maven/org/elasticsearch/ml/ml-cpp/$VERSION/ml-cpp-$VERSION-nodeps.zip"
+cp dev-tools/minimal.zip "${IVY_REPO}/maven/org/elasticsearch/ml/ml-cpp/$VERSION/ml-cpp-$VERSION-deps.zip"
+./dev-tools/run_es_inference_tests.sh ".." "$(cd "${IVY_REPO}" && pwd)"
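
The two `expr` calls in the script above split BUILDKITE_BRANCH around a colon, which suggests the value is expected in an `author:branch` form for pull request builds. Because the leading `.*` in each pattern is greedy, the split happens at the last colon, and a value without a colon yields empty strings for both parts. A small Python sketch of that behaviour, under those assumptions (the function name `split_pr_branch` is hypothetical):

```python
from typing import Tuple


def split_pr_branch(buildkite_branch: str) -> Tuple[str, str]:
    # rpartition splits at the LAST colon, matching the greedy '.*'
    # in the expr patterns '\(.*\):.*' and '.*:\(.*\)'.
    author, sep, source_branch = buildkite_branch.rpartition(":")
    if not sep:
        # No colon: neither expr pattern matches, so both captures are empty.
        return "", ""
    return author, source_branch


if __name__ == "__main__":
    print(split_pr_branch("someuser:feature/graph-validation"))
```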

bin/pytorch_inference/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -35,7 +35,9 @@ ml_add_executable(pytorch_inference
   CBufferedIStreamAdapter.cc
   CCmdLineParser.cc
   CCommandParser.cc
+  CModelGraphValidator.cc
   CResultWriter.cc
+  CSupportedOperations.cc
   CThreadSettings.cc
   )

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the following additional limitation. Functionality enabled by the
+ * files subject to the Elastic License 2.0 may only be used in production when
+ * invoked by an Elasticsearch process with a license key installed that permits
+ * use of machine learning features. You may not use this file except in
+ * compliance with the Elastic License 2.0 and the foregoing additional
+ * limitation.
+ */
+
+#include "CModelGraphValidator.h"
+
+#include "CSupportedOperations.h"
+
+#include <core/CLogger.h>
+
+#include <torch/csrc/jit/passes/inliner.h>
+
+#include <algorithm>
+
+namespace ml {
+namespace torch {
+
+CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) {
+
+    TStringSet observedOps;
+    std::size_t nodeCount{0};
+    collectModuleOps(module, observedOps, nodeCount);
+
+    if (nodeCount > MAX_NODE_COUNT) {
+        LOG_ERROR(<< "Model graph is too large: " << nodeCount
+                  << " nodes exceeds limit of " << MAX_NODE_COUNT);
+        return {false, {}, {}, nodeCount};
+    }
+
+    LOG_DEBUG(<< "Model graph contains " << observedOps.size()
+              << " distinct operations across " << nodeCount << " nodes");
+    for (const auto& op : observedOps) {
+        LOG_DEBUG(<< " observed op: " << op);
+    }
+
+    auto result = validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS,
+                           CSupportedOperations::FORBIDDEN_OPERATIONS);
+    result.s_NodeCount = nodeCount;
+    return result;
+}
+
+CModelGraphValidator::SResult
+CModelGraphValidator::validate(const TStringSet& observedOps,
+                               const std::unordered_set<std::string_view>& allowedOps,
+                               const std::unordered_set<std::string_view>& forbiddenOps) {
+
+    SResult result;
+
+    // Two-pass check: forbidden ops first, then unrecognised. This lets us
+    // fail fast when a known-dangerous operation is present and avoids the
+    // cost of scanning for unrecognised ops on a model we will reject anyway.
+    for (const auto& op : observedOps) {
+        if (forbiddenOps.contains(op)) {
+            result.s_IsValid = false;
+            result.s_ForbiddenOps.push_back(op);
+        }
+    }
+
+    if (result.s_ForbiddenOps.empty()) {
+        for (const auto& op : observedOps) {
+            if (allowedOps.contains(op) == false) {
+                result.s_IsValid = false;
+                result.s_UnrecognisedOps.push_back(op);
+            }
+        }
+    }
+
+    std::sort(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end());
+    std::sort(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end());
+
+    return result;
+}
+
+void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block,
+                                           TStringSet& ops,
+                                           std::size_t& nodeCount) {
+    for (const auto* node : block.nodes()) {
+        if (++nodeCount > MAX_NODE_COUNT) {
+            return;
+        }
+        ops.emplace(node->kind().toQualString());
+        for (const auto* subBlock : node->blocks()) {
+            collectBlockOps(*subBlock, ops, nodeCount);
+            if (nodeCount > MAX_NODE_COUNT) {
+                return;
+            }
+        }
+    }
+}
+
+void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module,
+                                            TStringSet& ops,
+                                            std::size_t& nodeCount) {
+    for (const auto& method : module.get_methods()) {
+        // Inline all method calls so that operations hidden behind
+        // prim::CallMethod are surfaced. After inlining, any remaining
+        // prim::CallMethod indicates a call that could not be resolved
+        // statically and will be flagged as unrecognised.
+        auto graph = method.graph()->copy();
+        ::torch::jit::Inline(*graph);
+        collectBlockOps(*graph->block(), ops, nodeCount);
+        if (nodeCount > MAX_NODE_COUNT) {
+            return;
+        }
+    }
+}
+}
+}
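
The matching logic in the set-based `validate` overload above can be sketched outside C++ to make the two-pass behaviour easier to follow. The Python below is an illustrative approximation, not the project's code: `Result` and `validate_ops` are hypothetical names, and the `example::` operation names are placeholders, since the actual contents of ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS are not shown in this diff.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Set


@dataclass
class Result:
    # Field names loosely follow SResult in the C++ code.
    is_valid: bool = True
    forbidden_ops: List[str] = field(default_factory=list)
    unrecognised_ops: List[str] = field(default_factory=list)


def validate_ops(observed: Iterable[str], allowed: Set[str], forbidden: Set[str]) -> Result:
    observed = set(observed)
    result = Result()
    # Pass 1: collect every observed op that is explicitly forbidden.
    result.forbidden_ops = sorted(op for op in observed if op in forbidden)
    # Pass 2: only when nothing is forbidden, collect ops missing from the allowlist.
    if not result.forbidden_ops:
        result.unrecognised_ops = sorted(op for op in observed if op not in allowed)
    result.is_valid = not (result.forbidden_ops or result.unrecognised_ops)
    return result


if __name__ == "__main__":
    allowed = {"aten::mul_", "quantized::linear_dynamic"}
    forbidden = {"example::forbidden_op"}  # placeholder name
    print(validate_ops({"aten::mul_", "example::unknown_op"}, allowed, forbidden))
```

One consequence of skipping the second pass is that a rejected model with a forbidden operation reports no unrecognised operations, even if it contains some.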
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the following additional limitation. Functionality enabled by the
+ * files subject to the Elastic License 2.0 may only be used in production when
+ * invoked by an Elasticsearch process with a license key installed that permits
+ * use of machine learning features. You may not use this file except in
+ * compliance with the Elastic License 2.0 and the foregoing additional
+ * limitation.
+ */
+
+#ifndef INCLUDED_ml_torch_CModelGraphValidator_h
+#define INCLUDED_ml_torch_CModelGraphValidator_h
+
+#include <torch/script.h>
+
+#include <string>
+#include <string_view>
+#include <unordered_set>
+#include <vector>
+
+namespace ml {
+namespace torch {
+
+//! \brief
+//! Validates TorchScript model computation graphs against a set of
+//! allowed operations.
+//!
+//! DESCRIPTION:\n
+//! Provides defense-in-depth by statically inspecting the TorchScript
+//! graph of a loaded model and rejecting any model that contains
+//! operations not present in the allowlist derived from supported
+//! transformer architectures.
+//!
+//! IMPLEMENTATION DECISIONS:\n
+//! The validation walks all methods of the module and its submodules
+//! recursively, collecting every distinct operation. Any operation
+//! that appears in the forbidden set causes immediate rejection.
+//! Any operation not in the allowed set is collected and reported.
+//! This ensures that even operations buried in helper methods or
+//! nested submodules are inspected.
+//!
+class CModelGraphValidator {
+public:
+    using TStringSet = std::unordered_set<std::string>;
+    using TStringVec = std::vector<std::string>;
+
+    //! Upper bound on the number of graph nodes we are willing to inspect.
+    //! Transformer models typically have O(10k) nodes after inlining; a
+    //! limit of 1M provides generous headroom while preventing a
+    //! pathologically large graph from consuming unbounded memory or CPU.
+    static constexpr std::size_t MAX_NODE_COUNT{1000000};
+
+    //! Result of validating a model graph.
+    struct SResult {
+        bool s_IsValid{true};
+        TStringVec s_ForbiddenOps;
+        TStringVec s_UnrecognisedOps;
+        std::size_t s_NodeCount{0};
+    };
+
+public:
+    //! Validate the computation graph of the given module against the
+    //! supported operation allowlist. Recursively inspects all methods
+    //! across all submodules.
+    static SResult validate(const ::torch::jit::Module& module);
+
+    //! Validate a pre-collected set of operation names. Useful for
+    //! unit testing the matching logic without requiring a real model.
+    static SResult validate(const TStringSet& observedOps,
+                            const std::unordered_set<std::string_view>& allowedOps,
+                            const std::unordered_set<std::string_view>& forbiddenOps);
+
+private:
+    //! Collect all operation names from a block, recursing into sub-blocks.
+    static void collectBlockOps(const ::torch::jit::Block& block,
+                                TStringSet& ops,
+                                std::size_t& nodeCount);
+
+    //! Inline all method calls and collect ops from the flattened graph.
+    //! After inlining, prim::CallMethod should not appear; if it does,
+    //! the call could not be resolved statically and is treated as
+    //! unrecognised.
+    static void collectModuleOps(const ::torch::jit::Module& module,
+                                 TStringSet& ops,
+                                 std::size_t& nodeCount);
+};
+}
+}
+
+#endif // INCLUDED_ml_torch_CModelGraphValidator_h
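
The header describes a recursive walk over blocks that stops once MAX_NODE_COUNT is exceeded. A rough Python model of that bounded traversal is sketched here, assuming a toy `Node` type with a `kind` string and a list of sub-blocks; it stands in for the TorchScript graph types and is not the libtorch API. The function returns the updated count instead of mutating a reference parameter as the C++ does.

```python
from dataclasses import dataclass, field
from typing import List, Set

MAX_NODE_COUNT = 1_000_000  # same value as the constant in the header


@dataclass
class Node:
    # Toy stand-in for a graph node: an op name plus nested blocks,
    # where each block is simply a list of nodes.
    kind: str
    blocks: List[List["Node"]] = field(default_factory=list)


def collect_block_ops(block: List[Node], ops: Set[str], count: int,
                      limit: int = MAX_NODE_COUNT) -> int:
    for node in block:
        count += 1
        if count > limit:
            # Budget exceeded: stop before recording this node.
            return count
        ops.add(node.kind)
        for sub_block in node.blocks:
            count = collect_block_ops(sub_block, ops, count, limit)
            if count > limit:
                return count
    return count


if __name__ == "__main__":
    graph = [
        Node("prim::If", blocks=[[Node("aten::add")], [Node("aten::mul")]]),
        Node("aten::linear"),
    ]
    seen: Set[str] = set()
    print(collect_block_ops(graph, seen, 0), sorted(seen))
```

A returned count greater than the limit signals that the walk was cut short, which corresponds to the "graph is too large" rejection in the C++ `validate`.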
