diff --git a/.buildkite/pipelines/run_qa_tests.yml.sh b/.buildkite/pipelines/run_qa_tests.yml.sh index 2b4fbb9d4..b606267b9 100755 --- a/.buildkite/pipelines/run_qa_tests.yml.sh +++ b/.buildkite/pipelines/run_qa_tests.yml.sh @@ -22,7 +22,13 @@ steps: - trigger: appex-qa-stateful-custom-ml-cpp-build-testing async: false build: - message: "${BUILDKITE_MESSAGE}" + message: | +EOL + +# Output the message with proper indentation for YAML literal block scalar +printf '%s\n' "${BUILDKITE_MESSAGE}" | sed 's/^/ /' + +cat < - -#include - -#include - -namespace ml { -namespace torch { - -CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) { - - TStringSet observedOps; - std::size_t nodeCount{0}; - collectModuleOps(module, observedOps, nodeCount); - - if (nodeCount > MAX_NODE_COUNT) { - LOG_ERROR(<< "Model graph is too large: " << nodeCount - << " nodes exceeds limit of " << MAX_NODE_COUNT); - return {false, {}, {}, nodeCount}; - } - - LOG_DEBUG(<< "Model graph contains " << observedOps.size() - << " distinct operations across " << nodeCount << " nodes"); - for (const auto& op : observedOps) { - LOG_DEBUG(<< " observed op: " << op); - } - - auto result = validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - result.s_NodeCount = nodeCount; - return result; -} - -CModelGraphValidator::SResult -CModelGraphValidator::validate(const TStringSet& observedOps, - const std::unordered_set& allowedOps, - const std::unordered_set& forbiddenOps) { - - SResult result; - - // Two-pass check: forbidden ops first, then unrecognised. This lets us - // fail fast when a known-dangerous operation is present and avoids the - // cost of scanning for unrecognised ops on a model we will reject anyway. - for (const auto& op : observedOps) { - if (forbiddenOps.contains(op)) { - result.s_IsValid = false; - result.s_ForbiddenOps.push_back(op); - } - } - - if (result.s_ForbiddenOps.empty()) { - for (const auto& op : observedOps) { - if (allowedOps.contains(op) == false) { - result.s_IsValid = false; - result.s_UnrecognisedOps.push_back(op); - } - } - } - - std::sort(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end()); - std::sort(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end()); - - return result; -} - -void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, - TStringSet& ops, - std::size_t& nodeCount) { - for (const auto* node : block.nodes()) { - if (++nodeCount > MAX_NODE_COUNT) { - return; - } - ops.emplace(node->kind().toQualString()); - for (const auto* subBlock : node->blocks()) { - collectBlockOps(*subBlock, ops, nodeCount); - if (nodeCount > MAX_NODE_COUNT) { - return; - } - } - } -} - -void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module, - TStringSet& ops, - std::size_t& nodeCount) { - for (const auto& method : module.get_methods()) { - // Inline all method calls so that operations hidden behind - // prim::CallMethod are surfaced. After inlining, any remaining - // prim::CallMethod indicates a call that could not be resolved - // statically and will be flagged as unrecognised. - auto graph = method.graph()->copy(); - ::torch::jit::Inline(*graph); - collectBlockOps(*graph->block(), ops, nodeCount); - if (nodeCount > MAX_NODE_COUNT) { - return; - } - } -} -} -} diff --git a/bin/pytorch_inference/CModelGraphValidator.h b/bin/pytorch_inference/CModelGraphValidator.h deleted file mode 100644 index 2c589dab5..000000000 --- a/bin/pytorch_inference/CModelGraphValidator.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the following additional limitation. Functionality enabled by the - * files subject to the Elastic License 2.0 may only be used in production when - * invoked by an Elasticsearch process with a license key installed that permits - * use of machine learning features. You may not use this file except in - * compliance with the Elastic License 2.0 and the foregoing additional - * limitation. - */ - -#ifndef INCLUDED_ml_torch_CModelGraphValidator_h -#define INCLUDED_ml_torch_CModelGraphValidator_h - -#include - -#include -#include -#include -#include - -namespace ml { -namespace torch { - -//! \brief -//! Validates TorchScript model computation graphs against a set of -//! allowed operations. -//! -//! DESCRIPTION:\n -//! Provides defense-in-depth by statically inspecting the TorchScript -//! graph of a loaded model and rejecting any model that contains -//! operations not present in the allowlist derived from supported -//! transformer architectures. -//! -//! IMPLEMENTATION DECISIONS:\n -//! The validation walks all methods of the module and its submodules -//! recursively, collecting every distinct operation. Any operation -//! that appears in the forbidden set causes immediate rejection. -//! Any operation not in the allowed set is collected and reported. -//! This ensures that even operations buried in helper methods or -//! nested submodules are inspected. -//! -class CModelGraphValidator { -public: - using TStringSet = std::unordered_set; - using TStringVec = std::vector; - - //! Upper bound on the number of graph nodes we are willing to inspect. - //! Transformer models typically have O(10k) nodes after inlining; a - //! limit of 1M provides generous headroom while preventing a - //! pathologically large graph from consuming unbounded memory or CPU. - static constexpr std::size_t MAX_NODE_COUNT{1000000}; - - //! Result of validating a model graph. - struct SResult { - bool s_IsValid{true}; - TStringVec s_ForbiddenOps; - TStringVec s_UnrecognisedOps; - std::size_t s_NodeCount{0}; - }; - -public: - //! Validate the computation graph of the given module against the - //! supported operation allowlist. Recursively inspects all methods - //! across all submodules. - static SResult validate(const ::torch::jit::Module& module); - - //! Validate a pre-collected set of operation names. Useful for - //! unit testing the matching logic without requiring a real model. - static SResult validate(const TStringSet& observedOps, - const std::unordered_set& allowedOps, - const std::unordered_set& forbiddenOps); - -private: - //! Collect all operation names from a block, recursing into sub-blocks. - static void collectBlockOps(const ::torch::jit::Block& block, - TStringSet& ops, - std::size_t& nodeCount); - - //! Inline all method calls and collect ops from the flattened graph. - //! After inlining, prim::CallMethod should not appear; if it does, - //! the call could not be resolved statically and is treated as - //! unrecognised. - static void collectModuleOps(const ::torch::jit::Module& module, - TStringSet& ops, - std::size_t& nodeCount); -}; -} -} - -#endif // INCLUDED_ml_torch_CModelGraphValidator_h diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc deleted file mode 100644 index 47fc60068..000000000 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the following additional limitation. Functionality enabled by the - * files subject to the Elastic License 2.0 may only be used in production when - * invoked by an Elasticsearch process with a license key installed that permits - * use of machine learning features. You may not use this file except in - * compliance with the Elastic License 2.0 and the foregoing additional - * limitation. - */ - -#include "CSupportedOperations.h" - -namespace ml { -namespace torch { - -using namespace std::string_view_literals; - -const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = { - // Arbitrary memory access — enables heap scanning, address leaks, and - // ROP chain construction. - "aten::as_strided"sv, - "aten::from_file"sv, - "aten::save"sv, - // After graph inlining, method and function calls should be resolved. - // Their presence indicates an opaque call that cannot be validated. - "prim::CallFunction"sv, - "prim::CallMethod"sv, -}; - -// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1. -// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, -// google/electra-small-discriminator, microsoft/mpnet-base, -// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base, -// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3, -// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english, -// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser, -// elastic/multilingual-e5-small-optimized, elastic/splade-v3, -// elastic/test-elser-v2. -// Additional ops from Elasticsearch integration test models -// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT). -// Quantized operations from dynamically quantized variants of the above -// models (torch.quantization.quantize_dynamic on nn.Linear layers). -const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = { - // aten operations (core tensor computations) - "aten::Int"sv, - "aten::IntImplicit"sv, - "aten::ScalarImplicit"sv, - "aten::__and__"sv, - "aten::abs"sv, - "aten::add"sv, - "aten::add_"sv, - "aten::arange"sv, - "aten::bitwise_not"sv, - "aten::cat"sv, - "aten::chunk"sv, - "aten::clamp"sv, - "aten::contiguous"sv, - "aten::cumsum"sv, - "aten::div"sv, - "aten::div_"sv, - "aten::dropout"sv, - "aten::embedding"sv, - "aten::expand"sv, - "aten::full_like"sv, - "aten::gather"sv, - "aten::ge"sv, - "aten::gelu"sv, - "aten::hash"sv, - "aten::index"sv, - "aten::index_put_"sv, - "aten::layer_norm"sv, - "aten::len"sv, - "aten::linear"sv, - "aten::log"sv, - "aten::lt"sv, - "aten::manual_seed"sv, - "aten::masked_fill"sv, - "aten::matmul"sv, - "aten::max"sv, - "aten::mean"sv, - "aten::min"sv, - "aten::mul"sv, - "aten::mul_"sv, - "aten::ne"sv, - "aten::neg"sv, - "aten::new_ones"sv, - "aten::ones"sv, - "aten::pad"sv, - "aten::permute"sv, - "aten::pow"sv, - "aten::rand"sv, - "aten::relu"sv, - "aten::repeat"sv, - "aten::reshape"sv, - "aten::rsub"sv, - "aten::scaled_dot_product_attention"sv, - "aten::select"sv, - "aten::size"sv, - "aten::slice"sv, - "aten::softmax"sv, - "aten::sqrt"sv, - "aten::squeeze"sv, - "aten::str"sv, - "aten::sub"sv, - "aten::tanh"sv, - "aten::tensor"sv, - "aten::to"sv, - "aten::transpose"sv, - "aten::type_as"sv, - "aten::unsqueeze"sv, - "aten::view"sv, - "aten::where"sv, - "aten::zeros"sv, - // prim operations (TorchScript graph infrastructure) - "prim::Constant"sv, - "prim::DictConstruct"sv, - "prim::GetAttr"sv, - "prim::If"sv, - "prim::ListConstruct"sv, - "prim::ListUnpack"sv, - "prim::Loop"sv, - "prim::NumToTensor"sv, - "prim::TupleConstruct"sv, - "prim::TupleUnpack"sv, - "prim::device"sv, - "prim::dtype"sv, - "prim::max"sv, - "prim::min"sv, - // quantized operations (dynamically quantized models, e.g. ELSER v2) - "quantized::linear_dynamic"sv, -}; -} -} diff --git a/bin/pytorch_inference/CSupportedOperations.h b/bin/pytorch_inference/CSupportedOperations.h deleted file mode 100644 index 3719bec80..000000000 --- a/bin/pytorch_inference/CSupportedOperations.h +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the following additional limitation. Functionality enabled by the - * files subject to the Elastic License 2.0 may only be used in production when - * invoked by an Elasticsearch process with a license key installed that permits - * use of machine learning features. You may not use this file except in - * compliance with the Elastic License 2.0 and the foregoing additional - * limitation. - */ - -#ifndef INCLUDED_ml_torch_CSupportedOperations_h -#define INCLUDED_ml_torch_CSupportedOperations_h - -#include -#include - -namespace ml { -namespace torch { - -//! \brief -//! Flat allowlist of TorchScript operations observed across all -//! supported transformer architectures (BERT, RoBERTa, DistilBERT, -//! ELECTRA, MPNet, DeBERTa, BART, DPR, MobileBERT, XLM-RoBERTa). -//! -//! DESCRIPTION:\n -//! Generated by tracing reference HuggingFace models with -//! dev-tools/extract_model_ops/extract_model_ops.py and collecting the union of all -//! operations from the inlined forward() computation graphs. -//! -//! IMPLEMENTATION DECISIONS:\n -//! Stored as a compile-time data structure rather than an external -//! config file to avoid runtime loading failures and to keep the -//! security boundary self-contained. The list should be regenerated -//! whenever the set of supported architectures changes or when -//! upgrading the PyTorch version. -//! -class CSupportedOperations { -public: - using TStringViewSet = std::unordered_set; - - //! Operations explicitly forbidden regardless of the allowlist. - //! - //! The forbidden list is checked separately from (and takes precedence - //! over) the allowed list. This two-tier approach provides: - //! - //! 1. Stable, targeted error messages for known-dangerous operations - //! (e.g. "model contains forbidden operation: aten::save") rather - //! than the generic "unrecognised operation" that the allowlist - //! would produce. This helps model authors diagnose rejections. - //! - //! 2. A safety net against accidental allowlist expansion. If a - //! future PyTorch upgrade or new architecture inadvertently adds - //! a dangerous op to the allowed set, the forbidden list still - //! blocks it. The forbidden check is independent of regeneration. - //! - //! 3. Defence-in-depth: two independent mechanisms must both agree - //! before an operation is permitted, reducing the risk of a - //! single-point allowlist error opening an attack vector. - static const TStringViewSet FORBIDDEN_OPERATIONS; - - //! Union of all TorchScript operations observed in supported architectures. - static const TStringViewSet ALLOWED_OPERATIONS; -}; -} -} - -#endif // INCLUDED_ml_torch_CSupportedOperations_h diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 4a7d2dde6..00adee1df 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -27,7 +27,6 @@ #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" #include "CCommandParser.h" -#include "CModelGraphValidator.h" #include "CResultWriter.h" #include "CThreadSettings.h" @@ -43,35 +42,24 @@ #include namespace { +// Add more forbidden ops here if needed +const std::unordered_set FORBIDDEN_OPERATIONS = {"aten::from_file", "aten::save"}; + void verifySafeModel(const torch::jit::script::Module& module_) { try { - auto result = ml::torch::CModelGraphValidator::validate(module_); - - if (result.s_ForbiddenOps.empty() == false) { - std::string ops = ml::core::CStringUtils::join(result.s_ForbiddenOps, ", "); - HANDLE_FATAL(<< "Model contains forbidden operations: " << ops); - } - - if (result.s_UnrecognisedOps.empty() == false) { - std::string ops = ml::core::CStringUtils::join(result.s_UnrecognisedOps, ", "); - HANDLE_FATAL(<< "Model graph does not match any supported architecture. " - << "Unrecognised operations: " << ops); - } - - if (result.s_NodeCount > ml::torch::CModelGraphValidator::MAX_NODE_COUNT) { - HANDLE_FATAL(<< "Model graph is too large: " << result.s_NodeCount << " nodes exceeds limit of " - << ml::torch::CModelGraphValidator::MAX_NODE_COUNT); - } - - if (result.s_IsValid == false) { - HANDLE_FATAL(<< "Model graph validation failed"); + const auto method = module_.get_method("forward"); + for (const auto graph = method.graph(); const auto& node : graph->nodes()) { + if (const std::string opName = node->kind().toQualString(); + FORBIDDEN_OPERATIONS.contains(opName)) { + HANDLE_FATAL(<< "Loading the inference process failed because it contains forbidden operation: " + << opName); + } } - - LOG_DEBUG(<< "Model verified: " << result.s_NodeCount - << " nodes, all operations match supported architectures."); } catch (const c10::Error& e) { - HANDLE_FATAL(<< "Model graph validation failed: " << e.what()); + LOG_FATAL(<< "Failed to get forward method: " << e.what()); } + + LOG_DEBUG(<< "Model verified: no forbidden operations detected."); } } diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index 5c7e7e4fd..7dcf6a7ef 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -9,7 +9,7 @@ * limitation. */ -#include +#include "../CCommandParser.h" #include diff --git a/bin/pytorch_inference/unittest/CMakeLists.txt b/bin/pytorch_inference/unittest/CMakeLists.txt index fe3c544a5..dd5394492 100644 --- a/bin/pytorch_inference/unittest/CMakeLists.txt +++ b/bin/pytorch_inference/unittest/CMakeLists.txt @@ -14,7 +14,6 @@ project("ML pytorch_inference unit tests") set (SRCS Main.cc CCommandParserTest.cc - CModelGraphValidatorTest.cc CResultWriterTest.cc CThreadSettingsTest.cc ) @@ -34,5 +33,3 @@ set(ML_LINK_LIBRARIES ) ml_add_test_executable(pytorch_inference ${SRCS}) - -target_include_directories(ml_test_pytorch_inference PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc deleted file mode 100644 index 7818e88f0..000000000 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the following additional limitation. Functionality enabled by the - * files subject to the Elastic License 2.0 may only be used in production when - * invoked by an Elasticsearch process with a license key installed that permits - * use of machine learning features. You may not use this file except in - * compliance with the Elastic License 2.0 and the foregoing additional - * limitation. - */ - -#include - -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include - -using namespace ml::torch; -using TStringSet = CModelGraphValidator::TStringSet; -using TStringViewSet = std::unordered_set; - -BOOST_AUTO_TEST_SUITE(CModelGraphValidatorTest) - -BOOST_AUTO_TEST_CASE(testAllAllowedOpsPass) { - // A model using only allowed ops should pass validation. - TStringSet observed{"aten::linear", "aten::layer_norm", "aten::gelu", - "aten::embedding", "prim::Constant", "prim::GetAttr"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testEmptyGraphPasses) { - TStringSet observed; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testForbiddenOpsRejected) { - TStringSet observed{"aten::linear", "aten::from_file", "prim::Constant"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testMultipleForbiddenOps) { - TStringSet observed{"aten::from_file", "aten::save"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(2, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); - BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[1]); -} - -BOOST_AUTO_TEST_CASE(testUnrecognisedOpsRejected) { - TStringSet observed{"aten::linear", "custom::evil_op", "prim::Constant"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); - BOOST_REQUIRE_EQUAL("custom::evil_op", result.s_UnrecognisedOps[0]); -} - -BOOST_AUTO_TEST_CASE(testMixedForbiddenAndUnrecognised) { - // When forbidden ops are present, the validator short-circuits and - // does not report unrecognised ops — we reject immediately. - TStringSet observed{"aten::save", "custom::backdoor", "aten::linear"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[0]); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testResultsSorted) { - TStringSet observed{"zzz::unknown", "aaa::unknown", "mmm::unknown"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(3, result.s_UnrecognisedOps.size()); - BOOST_REQUIRE_EQUAL("aaa::unknown", result.s_UnrecognisedOps[0]); - BOOST_REQUIRE_EQUAL("mmm::unknown", result.s_UnrecognisedOps[1]); - BOOST_REQUIRE_EQUAL("zzz::unknown", result.s_UnrecognisedOps[2]); -} - -BOOST_AUTO_TEST_CASE(testTypicalBertOps) { - // Simulate a realistic BERT-like op set. - TStringSet observed{"aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::div", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gelu", - "aten::ge", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::masked_fill", - "aten::matmul", - "aten::mul", - "aten::new_ones", - "aten::permute", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::size", - "aten::slice", - "aten::softmax", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::If", - "prim::ListConstruct", - "prim::NumToTensor", - "prim::TupleConstruct"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testCustomAllowlistAndForbiddenList) { - // Verify the three-argument overload works with arbitrary lists. - TStringViewSet allowed{"op::a", "op::b", "op::c"}; - TStringViewSet forbidden{"op::bad"}; - TStringSet observed{"op::a", "op::b"}; - - auto result = CModelGraphValidator::validate(observed, allowed, forbidden); - BOOST_REQUIRE(result.s_IsValid); - - observed.emplace("op::bad"); - result = CModelGraphValidator::validate(observed, allowed, forbidden); - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - - observed.erase("op::bad"); - observed.emplace("op::unknown"); - result = CModelGraphValidator::validate(observed, allowed, forbidden); - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); -} - -BOOST_AUTO_TEST_CASE(testCallMethodForbiddenAfterInlining) { - // prim::CallMethod must not appear after graph inlining; its presence - // means a method call could not be resolved and the graph cannot be - // fully validated. - TStringSet observed{"aten::linear", "prim::Constant", "prim::CallMethod"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("prim::CallMethod", result.s_ForbiddenOps[0]); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testCallFunctionForbiddenAfterInlining) { - TStringSet observed{"aten::linear", "prim::CallFunction"}; - - auto result = CModelGraphValidator::validate( - observed, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("prim::CallFunction", result.s_ForbiddenOps[0]); -} - -BOOST_AUTO_TEST_CASE(testMaxNodeCountConstant) { - BOOST_REQUIRE(CModelGraphValidator::MAX_NODE_COUNT > 0); - BOOST_REQUIRE_EQUAL(std::size_t{1000000}, CModelGraphValidator::MAX_NODE_COUNT); -} - -BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { - // If an op appears in both forbidden and allowed, forbidden takes precedence. - TStringViewSet allowed{"aten::from_file", "aten::linear"}; - TStringViewSet forbidden{"aten::from_file"}; - TStringSet observed{"aten::from_file", "aten::linear"}; - - auto result = CModelGraphValidator::validate(observed, allowed, forbidden); - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); - BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); -} - -// --- Integration tests using real TorchScript modules --- - -BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) { - // A simple module using only aten::add and aten::mul, both of which - // are in the allowed set. - ::torch::jit::Module m("__torch__.ValidModel"); - m.define(R"( - def forward(self, x: Tensor) -> Tensor: - return x + x * x - )"); - - auto result = CModelGraphValidator::validate(m); - - BOOST_REQUIRE(result.s_IsValid); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); - BOOST_REQUIRE(result.s_NodeCount > 0); -} - -BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) { - // torch.sin is not in the transformer allowlist. - ::torch::jit::Module m("__torch__.UnknownOps"); - m.define(R"( - def forward(self, x: Tensor) -> Tensor: - return torch.sin(x) - )"); - - auto result = CModelGraphValidator::validate(m); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); - bool foundSin = false; - for (const auto& op : result.s_UnrecognisedOps) { - if (op == "aten::sin") { - foundSin = true; - } - } - BOOST_REQUIRE(foundSin); -} - -BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) { - ::torch::jit::Module m("__torch__.NodeCount"); - m.define(R"( - def forward(self, x: Tensor) -> Tensor: - a = x + x - b = a * a - c = b - a - return c - )"); - - auto result = CModelGraphValidator::validate(m); - - BOOST_REQUIRE(result.s_NodeCount > 0); -} - -BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { - // Create a parent module with a child submodule. After inlining, - // the child's operations should be visible and validated. - ::torch::jit::Module child("__torch__.Child"); - child.define(R"( - def forward(self, x: Tensor) -> Tensor: - return torch.sin(x) - )"); - - ::torch::jit::Module parent("__torch__.Parent"); - parent.register_module("child", child); - parent.define(R"( - def forward(self, x: Tensor) -> Tensor: - return self.child.forward(x) + x - )"); - - auto result = CModelGraphValidator::validate(parent); - - BOOST_REQUIRE(result.s_IsValid == false); - bool foundSin = false; - for (const auto& op : result.s_UnrecognisedOps) { - if (op == "aten::sin") { - foundSin = true; - } - } - BOOST_REQUIRE(foundSin); -} - -// --- Integration tests with malicious .pt model fixtures --- -// -// These load real TorchScript models that simulate attack vectors. -// The .pt files are generated by testfiles/generate_malicious_models.py. - -namespace { -bool hasForbiddenOp(const CModelGraphValidator::SResult& result, const std::string& op) { - return std::find(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end(), - op) != result.s_ForbiddenOps.end(); -} - -bool hasUnrecognisedOp(const CModelGraphValidator::SResult& result, const std::string& op) { - return std::find(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end(), - op) != result.s_UnrecognisedOps.end(); -} -} - -BOOST_AUTO_TEST_CASE(testMaliciousFileReader) { - // A model that uses aten::from_file to read arbitrary files. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); -} - -BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { - // A model that mixes allowed ops (aten::add) with a forbidden - // aten::from_file. The entire model must be rejected. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_mixed_file_reader.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); - BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); -} - -BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { - // Unrecognised ops buried three levels deep in nested submodules. - // The validator must inline through all submodules to find them. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); -} - -BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { - // An unrecognised op hidden inside a conditional branch. The - // validator must recurse into prim::If blocks to detect it. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); -} - -BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { - // A model using many different unrecognised ops (sin, cos, tan, exp). - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 4); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp")); -} - -BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { - // The forbidden aten::from_file is hidden inside a submodule. - // After inlining, the validator must still detect it. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader_in_submodule.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); -} - -// --- Sandbox2 attack models --- -// -// These reproduce real-world attack vectors that exploit torch.as_strided -// to read out-of-bounds heap memory, leak libtorch addresses, and build -// ROP chains that call mprotect + shellcode to write arbitrary files. -// The graph validator must reject them because aten::as_strided is in -// the forbidden operations list. - -BOOST_AUTO_TEST_CASE(testMaliciousHeapLeak) { - // A model that uses torch.as_strided with a malicious storage offset - // to scan the heap for libtorch pointers and leak their addresses - // via an assertion message. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_heap_leak.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); -} - -BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { - // A model that extends the heap-leak technique to overwrite function - // pointers and build a ROP chain: mprotect a heap page as executable, - // then jump to shellcode that writes files to disk. - auto module = ::torch::jit::load("testfiles/malicious_models/malicious_rop_exploit.pt"); - auto result = CModelGraphValidator::validate(module); - - BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); -} - -// --- Allowlist drift detection --- -// -// Validates that ALLOWED_OPERATIONS covers every operation observed in -// the reference HuggingFace models. The golden file is generated by -// dev-tools/extract_model_ops/extract_model_ops.py --golden and should -// be regenerated whenever PyTorch is upgraded or the set of supported -// architectures changes. - -BOOST_AUTO_TEST_CASE(testAllowlistCoversReferenceModels) { - std::ifstream file("testfiles/reference_model_ops.json"); - BOOST_REQUIRE_MESSAGE(file.is_open(), - "Could not open testfiles/reference_model_ops.json — " - "regenerate with: python3 dev-tools/extract_model_ops/" - "extract_model_ops.py --golden " - "bin/pytorch_inference/unittest/testfiles/reference_model_ops.json"); - - std::ostringstream buf; - buf << file.rdbuf(); - auto root = boost::json::parse(buf.str()).as_object(); - - auto& models = root.at("models").as_object(); - BOOST_REQUIRE_MESSAGE(models.size() > 0, "Golden file contains no models"); - - const auto& allowed = CSupportedOperations::ALLOWED_OPERATIONS; - const auto& forbidden = CSupportedOperations::FORBIDDEN_OPERATIONS; - - for (const auto & [ arch, entry ] : models) { - const auto& info = entry.as_object(); - const auto& ops = info.at("ops").as_array(); - std::string modelId{info.at("model_id").as_string()}; - - for (const auto& opVal : ops) { - std::string op{opVal.as_string()}; - - BOOST_CHECK_MESSAGE(forbidden.count(op) == 0, - arch << " (" << modelId << "): op " << op << " is in FORBIDDEN_OPERATIONS — a legitimate model " - << "should not use forbidden ops"); - - BOOST_CHECK_MESSAGE(allowed.count(op) == 1, - arch << " (" << modelId << "): op " << op << " is not in ALLOWED_OPERATIONS — update the allowlist " - << "or check if this op was introduced by a PyTorch upgrade"); - } - } -} - -BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/CResultWriterTest.cc b/bin/pytorch_inference/unittest/CResultWriterTest.cc index 7803bbc39..97b99038a 100644 --- a/bin/pytorch_inference/unittest/CResultWriterTest.cc +++ b/bin/pytorch_inference/unittest/CResultWriterTest.cc @@ -9,9 +9,9 @@ * limitation. */ -#include +#include "../CResultWriter.h" -#include +#include "../CThreadSettings.h" #include #include diff --git a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc index 759affb02..8ab8d03d2 100644 --- a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc +++ b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc @@ -9,7 +9,7 @@ * limitation. */ -#include +#include "../CThreadSettings.h" #include diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt deleted file mode 100644 index 114707e6a..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt deleted file mode 100644 index fb0b26f46..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt deleted file mode 100644 index 4d6f6328b..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt deleted file mode 100644 index 3458ab76a..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt deleted file mode 100644 index 39104c647..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt deleted file mode 100644 index 68639503a..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt deleted file mode 100644 index 78b8c47c4..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt deleted file mode 100644 index 08beafc14..000000000 Binary files a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt and /dev/null differ diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json deleted file mode 100644 index 364d49f86..000000000 --- a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +++ /dev/null @@ -1,722 +0,0 @@ -{ - "pytorch_version": "2.7.1", - "models": { - "bert": { - "model_id": "bert-base-uncased", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "deberta": { - "model_id": "microsoft/deberta-base", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::add", - "aten::add_", - "aten::arange", - "aten::bitwise_not", - "aten::chunk", - "aten::clamp", - "aten::contiguous", - "aten::div", - "aten::div_", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::gelu", - "aten::linear", - "aten::masked_fill", - "aten::matmul", - "aten::mean", - "aten::mul", - "aten::ne", - "aten::neg", - "aten::permute", - "aten::pow", - "aten::repeat", - "aten::rsub", - "aten::select", - "aten::size", - "aten::slice", - "aten::softmax", - "aten::sqrt", - "aten::squeeze", - "aten::sub", - "aten::tensor", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::If", - "prim::ListConstruct", - "prim::ListUnpack", - "prim::TupleConstruct", - "prim::TupleUnpack", - "prim::device", - "prim::max", - "prim::min" - ] - }, - "distilbert": { - "model_id": "distilbert-base-uncased", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "dpr": { - "model_id": "facebook/dpr-ctx_encoder-single-nq-base", - "ops": [ - "aten::Int", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "aten::zeros", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-bge-m3": { - "model_id": "elastic/bge-m3", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::cumsum", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::mul", - "aten::ne", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::type_as", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-distilbert-cased-ner": { - "model_id": "elastic/distilbert-base-cased-finetuned-conll03-english", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-distilbert-uncased-ner": { - "model_id": "elastic/distilbert-base-uncased-finetuned-conll03-english", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-eis-elser-v2": { - "model_id": "elastic/eis-elser-v2", - "quantized": false, - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-elser-v2": { - "model_id": "elastic/elser-v2", - "quantized": false, - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-elser-v2-quantized": { - "model_id": "elastic/elser-v2", - "quantized": true, - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::mul_", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor", - "quantized::linear_dynamic" - ] - }, - "elastic-hugging-face-elser": { - "model_id": "elastic/hugging-face-elser", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-multilingual-e5-small-optimized": { - "model_id": "elastic/multilingual-e5-small-optimized", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-splade-v3": { - "model_id": "elastic/splade-v3", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "elastic-test-elser-v2": { - "model_id": "elastic/test-elser-v2", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "electra": { - "model_id": "google/electra-small-discriminator", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "mobilebert": { - "model_id": "google/mobilebert-uncased", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::cat", - "aten::contiguous", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::ge", - "aten::index", - "aten::linear", - "aten::mul", - "aten::new_ones", - "aten::pad", - "aten::relu", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::to", - "aten::transpose", - "aten::unsqueeze", - "aten::view", - "aten::zeros", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor", - "prim::TupleConstruct", - "prim::TupleUnpack" - ] - }, - "mpnet": { - "model_id": "microsoft/mpnet-base", - "ops": [ - "aten::abs", - "aten::add", - "aten::add_", - "aten::arange", - "aten::contiguous", - "aten::cumsum", - "aten::div", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::full_like", - "aten::gelu", - "aten::layer_norm", - "aten::linear", - "aten::log", - "aten::lt", - "aten::matmul", - "aten::min", - "aten::mul", - "aten::ne", - "aten::neg", - "aten::permute", - "aten::rsub", - "aten::select", - "aten::size", - "aten::slice", - "aten::softmax", - "aten::sub", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::type_as", - "aten::unsqueeze", - "aten::view", - "aten::where", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct" - ] - }, - "roberta": { - "model_id": "roberta-base", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::cumsum", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::mul", - "aten::ne", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::type_as", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - }, - "xlm-roberta": { - "model_id": "xlm-roberta-base", - "ops": [ - "aten::Int", - "aten::ScalarImplicit", - "aten::__and__", - "aten::add", - "aten::arange", - "aten::contiguous", - "aten::cumsum", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::index", - "aten::layer_norm", - "aten::linear", - "aten::mul", - "aten::ne", - "aten::new_ones", - "aten::reshape", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::tanh", - "aten::to", - "aten::transpose", - "aten::type_as", - "aten::unsqueeze", - "aten::view", - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::ListConstruct", - "prim::NumToTensor" - ] - } - } -} diff --git a/cmake/functions.cmake b/cmake/functions.cmake index 872bf3c6d..01502aaca 100644 --- a/cmake/functions.cmake +++ b/cmake/functions.cmake @@ -565,12 +565,5 @@ add_custom_target(check_style add_custom_target(precommit COMMENT "Running essential tasks prior to code commit" DEPENDS format test - COMMAND ${CMAKE_COMMAND} - -DSOURCE_DIR=${CMAKE_SOURCE_DIR} - -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json - -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models - -DVALIDATE_VERBOSE=TRUE - -DOPTIONAL=TRUE - -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ) diff --git a/cmake/run-validation.cmake b/cmake/run-validation.cmake deleted file mode 100644 index f1197eb19..000000000 --- a/cmake/run-validation.cmake +++ /dev/null @@ -1,186 +0,0 @@ -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# - -# Portable CMake script that locates a Python 3 interpreter, ensures a -# virtual environment with the required packages exists, and then runs -# validate_allowlist.py. -# -# Required variables (passed via -D on command line): -# SOURCE_DIR - path to the repository root -# -# Optional variables: -# VALIDATE_CONFIG - path to validation_models.json -# VALIDATE_PT_DIR - directory of .pt files to validate -# VALIDATE_VERBOSE - if TRUE, pass --verbose to the script -# OPTIONAL - if TRUE, skip gracefully when Python 3 is not -# found or dependency installation fails (instead -# of failing the build). Intended for use when -# this script is invoked as part of a broader test -# target where the environment may not have Python -# or network access. - -cmake_minimum_required(VERSION 3.16) - -if(NOT DEFINED SOURCE_DIR) - message(FATAL_ERROR "SOURCE_DIR must be defined") -endif() - -# Helper: emit a FATAL_ERROR or a WARNING+return depending on OPTIONAL. -macro(_validation_fail _msg) - if(DEFINED OPTIONAL AND OPTIONAL) - message(WARNING "Skipping validation: ${_msg}") - return() - else() - message(FATAL_ERROR "${_msg}") - endif() -endmacro() - -set(_tools_dir "${SOURCE_DIR}/dev-tools/extract_model_ops") -set(_venv_dir "${_tools_dir}/.venv") -set(_requirements "${_tools_dir}/requirements.txt") -set(_validate_script "${_tools_dir}/validate_allowlist.py") - -# --- Locate a Python 3 interpreter --- -# Try names in order of preference. On Linux build machines Python may -# only be available as python3.12 (installed via make altinstall). -# On Windows the canonical name is just "python". -find_program(_python_path - NAMES python3 python3.12 python3.11 python3.10 python3.9 python - DOC "Python 3 interpreter" -) - -if(NOT _python_path) - _validation_fail( - "No Python 3 interpreter found on PATH.\n" - "Install Python 3 or ensure it is on your PATH.") -endif() - -# Verify it is actually Python 3 (guards against "python" being Python 2). -execute_process( - COMMAND "${_python_path}" --version - OUTPUT_VARIABLE _py_version_out - ERROR_VARIABLE _py_version_out - RESULT_VARIABLE _py_rc - OUTPUT_STRIP_TRAILING_WHITESPACE -) -if(NOT _py_rc EQUAL 0 OR NOT _py_version_out MATCHES "Python 3\\.") - _validation_fail( - "Found ${_python_path} but it is not Python 3 (${_py_version_out}).") -endif() -message(STATUS "Found Python 3: ${_python_path} (${_py_version_out})") - -# --- Platform-specific venv paths --- -if(CMAKE_HOST_WIN32) - set(_venv_python "${_venv_dir}/Scripts/python.exe") - set(_venv_pip "${_venv_dir}/Scripts/pip.exe") -else() - set(_venv_python "${_venv_dir}/bin/python3") - set(_venv_pip "${_venv_dir}/bin/pip") -endif() - -# --- Create virtual environment if it does not exist --- -if(NOT EXISTS "${_venv_python}") - message(STATUS "Creating virtual environment in ${_venv_dir} ...") - execute_process( - COMMAND "${_python_path}" -m venv "${_venv_dir}" - RESULT_VARIABLE _venv_rc - ) - if(NOT _venv_rc EQUAL 0) - _validation_fail("Failed to create virtual environment (exit ${_venv_rc})") - endif() -endif() - -# --- Install / update dependencies when requirements.txt is newer --- -set(_stamp "${_venv_dir}/.requirements.stamp") -set(_needs_install FALSE) - -if(NOT EXISTS "${_stamp}") - set(_needs_install TRUE) -else() - file(TIMESTAMP "${_requirements}" _req_ts "%Y%m%d%H%M%S" UTC) - file(TIMESTAMP "${_stamp}" _stamp_ts "%Y%m%d%H%M%S" UTC) - if(_req_ts STRGREATER _stamp_ts) - set(_needs_install TRUE) - endif() -endif() - -if(_needs_install) - message(STATUS "Installing/updating Python dependencies ...") - execute_process( - COMMAND "${_venv_pip}" install --quiet --upgrade pip - RESULT_VARIABLE _pip_rc - ) - if(NOT _pip_rc EQUAL 0) - message(WARNING "pip upgrade failed (exit ${_pip_rc}) — continuing anyway") - endif() - - execute_process( - COMMAND "${_venv_pip}" install --quiet -r "${_requirements}" - RESULT_VARIABLE _pip_rc - ) - if(NOT _pip_rc EQUAL 0) - _validation_fail( - "Failed to install dependencies from ${_requirements} (exit ${_pip_rc}).\n" - "This may indicate no network access is available.") - endif() - - file(WRITE "${_stamp}" "installed") -endif() - -# --- Ensure the venv's torch libraries take precedence --- -# When a locally-built libtorch is installed in a system path (e.g. -# /usr/local/lib on macOS), the pip-installed torch package's -# libtorch_python will pick up the wrong libtorch_cpu at load time. -# Prepending the venv's torch/lib directory to the dynamic library -# search path forces the pip-bundled libraries to be found first. -if(CMAKE_HOST_WIN32) - set(_venv_site_packages "${_venv_dir}/Lib/site-packages") -else() - # Discover the site-packages directory (Python version varies) - file(GLOB _venv_site_packages "${_venv_dir}/lib/python*/site-packages") -endif() -set(_torch_lib_dir "${_venv_site_packages}/torch/lib") - -if(EXISTS "${_torch_lib_dir}") - if(CMAKE_HOST_APPLE) - set(ENV{DYLD_LIBRARY_PATH} "${_torch_lib_dir}:$ENV{DYLD_LIBRARY_PATH}") - elseif(NOT CMAKE_HOST_WIN32) - set(ENV{LD_LIBRARY_PATH} "${_torch_lib_dir}:$ENV{LD_LIBRARY_PATH}") - endif() - message(STATUS "Prepended ${_torch_lib_dir} to dynamic library search path") -endif() - -# --- Build the command line for validate_allowlist.py --- -set(_cmd "${_venv_python}" "${_validate_script}") - -if(DEFINED VALIDATE_CONFIG) - list(APPEND _cmd "--config" "${VALIDATE_CONFIG}") -endif() - -if(DEFINED VALIDATE_PT_DIR) - list(APPEND _cmd "--pt-dir" "${VALIDATE_PT_DIR}") -endif() - -if(DEFINED VALIDATE_VERBOSE AND VALIDATE_VERBOSE) - list(APPEND _cmd "--verbose") -endif() - -message(STATUS "Running: ${_cmd}") - -execute_process( - COMMAND ${_cmd} - WORKING_DIRECTORY "${SOURCE_DIR}" - RESULT_VARIABLE _validate_rc -) - -if(NOT _validate_rc EQUAL 0) - _validation_fail("Validation failed (exit ${_validate_rc})") -endif() diff --git a/dev-tools/extract_model_ops/.gitignore b/dev-tools/extract_model_ops/.gitignore deleted file mode 100644 index 21d0b898f..000000000 --- a/dev-tools/extract_model_ops/.gitignore +++ /dev/null @@ -1 +0,0 @@ -.venv/ diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md deleted file mode 100644 index f7b7f2f39..000000000 --- a/dev-tools/extract_model_ops/README.md +++ /dev/null @@ -1,166 +0,0 @@ -# extract_model_ops - -Developer tools for maintaining and validating the TorchScript operation -allowlist in `bin/pytorch_inference/CSupportedOperations.cc`. - -This directory contains two scripts that share the same Python environment: - -| Script | Purpose | -|---|---| -| `extract_model_ops.py` | Generate the C++ `ALLOWED_OPERATIONS` set from reference models | -| `validate_allowlist.py` | Verify the allowlist accepts all supported models (no false positives) | - -## Setup - -Create a Python virtual environment and install the dependencies: - -```bash -cd dev-tools/extract_model_ops -python3 -m venv .venv -source .venv/bin/activate -pip install -r requirements.txt -``` - -If any of the reference models are gated, set a HuggingFace token: - -```bash -export HF_TOKEN="hf_..." -``` - -## extract_model_ops.py - -Traces each model in `reference_models.json`, collects the TorchScript -operations from the inlined forward graph, and outputs the union as a -sorted list or a ready-to-paste C++ initializer. - -### When to run - -- A new transformer architecture is added to the supported set. -- The PyTorch (libtorch) version used by ml-cpp is upgraded. -- You need to inspect which operations a particular model uses. - -### Usage - -```bash -# Print the sorted union of all operations (default) -python3 extract_model_ops.py - -# Print a ready-to-paste C++ initializer list -python3 extract_model_ops.py --cpp - -# Also show per-model breakdowns -python3 extract_model_ops.py --per-model --cpp - -# Generate the golden file for the C++ allowlist drift test -python3 extract_model_ops.py --golden \ - ../../bin/pytorch_inference/unittest/testfiles/reference_model_ops.json - -# Use a custom config file -python3 extract_model_ops.py --config /path/to/models.json -``` - -## validate\_allowlist.py - -Parses `ALLOWED_OPERATIONS` and `FORBIDDEN_OPERATIONS` directly from -`CSupportedOperations.cc`, then traces every model in a config file and -checks that each model's operations are accepted. Exits non-zero if -any model would be rejected (a false positive). - -### When to run - -- After regenerating `ALLOWED_OPERATIONS` with `extract_model_ops.py`. -- After adding new models to `validation_models.json`. -- As a pre-merge check for any PR that touches the allowlist or the - graph validation logic. - -### Usage - -```bash -# Validate against the default set (validation_models.json) -python3 validate_allowlist.py - -# Validate with verbose per-model op counts -python3 validate_allowlist.py --verbose - -# Validate against a custom model set -python3 validate_allowlist.py --config /path/to/models.json -``` - -The script can also be run via the CMake `validate_pytorch_inference_models` -target, which automatically locates a Python 3 interpreter, creates a venv, -and installs dependencies — no manual setup required: - -```bash -cmake --build cmake-build-relwithdebinfo -t validate_pytorch_inference_models -``` - -The CMake target searches for `python3`, `python3.12`, `python3.11`, -`python3.10`, `python3.9`, and `python` (in that order), accepting the -first one that reports Python 3.x. This handles Linux build machines -where Python is only available as `python3.12` (via `make altinstall`) -as well as Windows where the canonical name is `python`. - -## Configuration files - -| File | Used by | Purpose | -|---|---|---| -| `reference_models.json` | `extract_model_ops.py` | Models whose ops form the allowlist | -| `validation_models.json` | `validate_allowlist.py` | Superset including task-specific models (NER, sentiment) from `bin/pytorch_inference/examples/` | - -Each file maps a short architecture name to a HuggingFace model identifier: - -```json -{ - "bert": "bert-base-uncased", - "roberta": "roberta-base" -} -``` - -To add a new architecture, append an entry to `reference_models.json`, -re-run `extract_model_ops.py --cpp`, and update `CSupportedOperations.cc`. -Then add the same entry (plus any task-specific variants) to -`validation_models.json` and run `validate_allowlist.py` to confirm -there are no false positives. Finally, regenerate the golden file -(see below). - -## Golden file for allowlist drift detection - -The C++ test `testAllowlistCoversReferenceModels` loads a golden JSON -file containing per-architecture op sets and verifies every op is in -`ALLOWED_OPERATIONS` and none are in `FORBIDDEN_OPERATIONS`. This -catches allowlist regressions in CI without requiring Python or network -access. - -The golden file lives at: -`bin/pytorch_inference/unittest/testfiles/reference_model_ops.json` - -### When to regenerate - -- After upgrading the PyTorch (libtorch) version. -- After adding or removing a supported architecture. -- After modifying `ALLOWED_OPERATIONS` or `FORBIDDEN_OPERATIONS`. - -### How to regenerate - -```bash -cd dev-tools/extract_model_ops -source .venv/bin/activate -python3 extract_model_ops.py --golden \ - ../../bin/pytorch_inference/unittest/testfiles/reference_model_ops.json -``` - -If the regenerated file introduces ops not in the allowlist, the C++ -test will fail until `CSupportedOperations.cc` is updated. - -## How it works - -1. Each reference model is loaded via `transformers.AutoModel` with - `torchscript=True` in the config. -2. The model is traced with `torch.jit.trace` using a short dummy input - (falls back to `torch.jit.script` if tracing fails). -3. All method calls in the forward graph are inlined via - `torch._C._jit_pass_inline` so that operations inside submodules - are visible. -4. Every node's operation name (`node.kind()`) is collected, recursing - into sub-blocks (e.g. inside `prim::If` / `prim::Loop` nodes). -5. The union across all models is reported. diff --git a/dev-tools/extract_model_ops/es_it_models/README.md b/dev-tools/extract_model_ops/es_it_models/README.md deleted file mode 100644 index a3997d2ef..000000000 --- a/dev-tools/extract_model_ops/es_it_models/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Elasticsearch Integration Test Models - -Pre-saved TorchScript `.pt` files extracted from the base64-encoded models -in the Elasticsearch Java integration tests. These are tiny synthetic models -(not real transformer architectures) used to test the `pytorch_inference` -loading and evaluation pipeline. - -| File | Source | Description | -|------|--------|-------------| -| `supersimple_pytorch_model_it.pt` | `PyTorchModelIT.java` | Returns `torch.ones` of shape `(batch, 2)` | -| `tiny_text_expansion.pt` | `TextExpansionQueryIT.java` | Sparse weight vector sized by max input ID | -| `tiny_text_embedding.pt` | `TextEmbeddingQueryIT.java` | Random 100-dim embedding seeded by input hash | - -## Regenerating - -If the Java test models change, re-extract them by running the generation -snippet from this repository's root: - -```bash -python3 -c " -import re, base64, os - -JAVA_DIR = '/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration' -OUTPUT_DIR = 'dev-tools/extract_model_ops/es_it_models' - -SOURCES = { - 'supersimple_pytorch_model_it.pt': ('PyTorchModelIT.java', 'BASE_64_ENCODED_MODEL'), - 'tiny_text_expansion.pt': ('TextExpansionQueryIT.java', 'BASE_64_ENCODED_MODEL'), - 'tiny_text_embedding.pt': ('TextEmbeddingQueryIT.java', 'BASE_64_ENCODED_MODEL'), -} -os.makedirs(OUTPUT_DIR, exist_ok=True) -for out_name, (java_file, var_name) in SOURCES.items(): - with open(os.path.join(JAVA_DIR, java_file)) as f: - src = f.read() - m = re.search(rf'{var_name}\s*=\s*(\".*?\");', src, re.DOTALL) - b64 = re.sub(r'\"\s*\+\s*\"', '', m.group(1)).strip('\"').replace('\n', '').replace(' ', '') - with open(os.path.join(OUTPUT_DIR, out_name), 'wb') as f: - f.write(base64.b64decode(b64)) - print(f'Wrote {out_name}') -" -``` diff --git a/dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt b/dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt deleted file mode 100644 index 0eecbb1b3..000000000 Binary files a/dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt and /dev/null differ diff --git a/dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt b/dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt deleted file mode 100644 index 933a50b95..000000000 Binary files a/dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt and /dev/null differ diff --git a/dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt b/dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt deleted file mode 100644 index a4c0abe6a..000000000 Binary files a/dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt and /dev/null differ diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py deleted file mode 100644 index 562590c31..000000000 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Extract TorchScript operation sets from supported HuggingFace transformer architectures. - -This developer tool traces/scripts reference models and collects the set of -TorchScript operations that appear in their forward() computation graphs. -The output is a sorted, de-duplicated union of all operations which can be -used to build the C++ allowlist in CSupportedOperations.h. - -Usage: - python3 extract_model_ops.py [--per-model] [--cpp] [--golden OUTPUT] [--config CONFIG] - -Flags: - --per-model Print the op set for each model individually. - --cpp Print the union as a C++ initializer list. - --golden OUTPUT Write per-model op sets as a JSON golden file for the - C++ allowlist drift test. - --config CONFIG Path to the reference models JSON config file. - Defaults to reference_models.json in the same directory. -""" - -import argparse -import json -import sys -from pathlib import Path - -import torch - -from torchscript_utils import ( - collect_inlined_ops, - load_and_trace_hf_model, - load_model_config, -) - -SCRIPT_DIR = Path(__file__).resolve().parent -DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json" - - -def extract_ops_for_model(model_name: str, - quantize: bool = False) -> set[str] | None: - """Trace a HuggingFace model and return its TorchScript op set. - - Returns None if the model could not be loaded or traced. - """ - label = f"{model_name} (quantized)" if quantize else model_name - print(f" Loading {label}...", file=sys.stderr) - traced = load_and_trace_hf_model(model_name, quantize=quantize) - if traced is None: - return None - return collect_inlined_ops(traced) - - -def format_cpp_initializer(ops: set[str]) -> str: - """Format the op set as a C++ initializer list for std::unordered_set.""" - sorted_ops = sorted(ops) - lines = [] - for op in sorted_ops: - lines.append(f' "{op}"sv,') - return "{\n" + "\n".join(lines) + "\n}" - - -def main(): - parser = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("--per-model", action="store_true", - help="Print per-model op sets") - parser.add_argument("--cpp", action="store_true", - help="Print union as C++ initializer") - parser.add_argument("--golden", type=Path, default=None, metavar="OUTPUT", - help="Write per-model op sets as a JSON golden file") - parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG, - help="Path to reference_models.json config file") - args = parser.parse_args() - - reference_models = load_model_config(args.config) - - per_model_ops = {} - union_ops = set() - - print("Extracting TorchScript ops from supported architectures...", - file=sys.stderr) - - failed = [] - for arch, spec in reference_models.items(): - ops = extract_ops_for_model(spec["model_id"], - quantize=spec["quantized"]) - if ops is None: - failed.append(arch) - print(f" {arch}: FAILED", file=sys.stderr) - continue - per_model_ops[arch] = ops - union_ops.update(ops) - print(f" {arch}: {len(ops)} ops", file=sys.stderr) - - print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) - if failed: - print(f"Failed models: {', '.join(failed)}", file=sys.stderr) - - if args.golden: - golden = { - "pytorch_version": torch.__version__, - "models": { - arch: { - "model_id": reference_models[arch]["model_id"], - "quantized": reference_models[arch]["quantized"], - "ops": sorted(ops), - } - for arch, ops in sorted(per_model_ops.items()) - }, - } - args.golden.parent.mkdir(parents=True, exist_ok=True) - with open(args.golden, "w") as f: - json.dump(golden, f, indent=2) - f.write("\n") - print(f"Wrote golden file to {args.golden} " - f"({len(per_model_ops)} models, " - f"{len(union_ops)} unique ops)", file=sys.stderr) - - if args.per_model: - for arch, ops in sorted(per_model_ops.items()): - spec = reference_models[arch] - label = spec["model_id"] - if spec["quantized"]: - label += " (quantized)" - print(f"\n=== {arch} ({label}) ===") - for op in sorted(ops): - print(f" {op}") - - if args.cpp: - print("\n// C++ initializer for SUPPORTED_OPERATIONS:") - print(format_cpp_initializer(union_ops)) - elif not args.golden: - print("\n// Sorted union of all operations:") - for op in sorted(union_ops): - print(op) - - -if __name__ == "__main__": - main() diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json deleted file mode 100644 index 76aefef01..000000000 --- a/dev-tools/extract_model_ops/reference_models.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "bert": "bert-base-uncased", - "roberta": "roberta-base", - "distilbert": "distilbert-base-uncased", - "electra": "google/electra-small-discriminator", - "mpnet": "microsoft/mpnet-base", - "deberta": "microsoft/deberta-base", - "dpr": "facebook/dpr-ctx_encoder-single-nq-base", - "mobilebert": "google/mobilebert-uncased", - "xlm-roberta": "xlm-roberta-base", - "elastic-bge-m3": "elastic/bge-m3", - "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", - "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", - "elastic-eis-elser-v2": "elastic/eis-elser-v2", - "elastic-elser-v2": "elastic/elser-v2", - "elastic-hugging-face-elser": "elastic/hugging-face-elser", - "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", - "elastic-splade-v3": "elastic/splade-v3", - "elastic-test-elser-v2": "elastic/test-elser-v2", - - "_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.", - "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true}, - "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, - "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true} -} diff --git a/dev-tools/extract_model_ops/requirements.txt b/dev-tools/extract_model_ops/requirements.txt deleted file mode 100644 index 70d0ebb78..000000000 --- a/dev-tools/extract_model_ops/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -torch==2.7.1 -transformers>=4.40.0 -sentencepiece>=0.2.0 -protobuf>=5.0.0 diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py deleted file mode 100644 index 33042f261..000000000 --- a/dev-tools/extract_model_ops/torchscript_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Shared utilities for extracting and inspecting TorchScript operations.""" - -import json -import os -import sys -from pathlib import Path - -import torch -from transformers import AutoConfig, AutoModel, AutoTokenizer - - -def load_model_config(config_path: Path) -> dict[str, dict]: - """Load a model config JSON file and normalise entries. - - Each entry is either a plain model-name string or a dict with - ``model_id`` (required) and optional ``quantized`` boolean. All - entries are normalised to ``{"model_id": str, "quantized": bool}``. - Keys starting with ``_comment`` are silently skipped. - - Raises ``ValueError`` for malformed entries so that config problems - are caught early with an actionable message. - """ - with open(config_path) as f: - raw = json.load(f) - - models: dict[str, dict] = {} - for key, value in raw.items(): - if key.startswith("_comment"): - continue - if isinstance(value, str): - models[key] = {"model_id": value, "quantized": False} - elif isinstance(value, dict): - if "model_id" not in value: - raise ValueError( - f"Config entry {key!r} is a dict but missing required " - f"'model_id' key: {value!r}") - models[key] = { - "model_id": value["model_id"], - "quantized": value.get("quantized", False), - } - else: - raise ValueError( - f"Config entry {key!r} has unsupported type " - f"{type(value).__name__}: {value!r}. " - f"Expected a model name string or a dict with 'model_id'.") - return models - - -def collect_graph_ops(graph) -> set[str]: - """Collect all operation names from a TorchScript graph, including blocks.""" - ops = set() - for node in graph.nodes(): - ops.add(node.kind()) - for block in node.blocks(): - ops.update(collect_graph_ops(block)) - return ops - - -def collect_inlined_ops(module) -> set[str]: - """Clone the forward graph, inline all calls, and return the op set.""" - graph = module.forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) - - -def load_and_trace_hf_model(model_name: str, quantize: bool = False): - """Load a HuggingFace model, tokenize sample input, and trace to TorchScript. - - When *quantize* is True the model is dynamically quantized (nn.Linear - layers converted to quantized::linear_dynamic) before tracing. This - mirrors what Eland does when importing models for Elasticsearch. - - Returns the traced module, or None if the model could not be loaded or traced. - """ - token = os.environ.get("HF_TOKEN") - - try: - tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) - config = AutoConfig.from_pretrained( - model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained( - model_name, config=config, token=token) - model.eval() - except Exception as exc: - print(f" LOAD ERROR: {exc}", file=sys.stderr) - return None - - if quantize: - try: - model = torch.quantization.quantize_dynamic( - model, {torch.nn.Linear}, dtype=torch.qint8) - print(" Applied dynamic quantization (nn.Linear -> qint8)", - file=sys.stderr) - except Exception as exc: - print(f" QUANTIZE ERROR: {exc}", file=sys.stderr) - return None - - inputs = tokenizer( - "This is a sample input for graph extraction.", - return_tensors="pt", padding="max_length", - max_length=32, truncation=True) - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - try: - return torch.jit.trace( - model, (input_ids, attention_mask), strict=False) - except Exception as exc: - print(f" TRACE WARNING: {exc}", file=sys.stderr) - print(" Falling back to torch.jit.script...", file=sys.stderr) - try: - return torch.jit.script(model) - except Exception as exc2: - print(f" SCRIPT ERROR: {exc2}", file=sys.stderr) - return None diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py deleted file mode 100644 index 828749dbc..000000000 --- a/dev-tools/extract_model_ops/validate_allowlist.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Validate that the C++ operation allowlist accepts all supported model architectures. - -Traces each model listed in a JSON config file, extracts its TorchScript -operations (using the same inlining approach as the C++ validator), and -checks every operation against the ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS -sets parsed from CSupportedOperations.cc. - -This is the Python-side equivalent of the C++ CModelGraphValidator and is -intended as an integration test: if any legitimate model produces an -operation that the C++ code would reject, this script exits non-zero. - -Exit codes: - 0 All models pass (no false positives). - 1 At least one model was rejected or a model failed to load/trace. - -Usage: - python3 validate_allowlist.py [--config CONFIG] [--verbose] -""" - -import argparse -import re -import sys -from pathlib import Path - -import torch - -from torchscript_utils import ( - collect_graph_ops, - collect_inlined_ops, - load_and_trace_hf_model, - load_model_config, -) - -SCRIPT_DIR = Path(__file__).resolve().parent -REPO_ROOT = SCRIPT_DIR.parents[1] -DEFAULT_CONFIG = SCRIPT_DIR / "validation_models.json" -SUPPORTED_OPS_CC = REPO_ROOT / "bin" / "pytorch_inference" / "CSupportedOperations.cc" - - -def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]: - """Extract a set of string literals from a C++ TStringViewSet definition.""" - text = path.read_text() - pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise RuntimeError(f"Could not find {variable_name} in {path}") - block = match.group(1) - return set(re.findall(r'"([^"]+)"', block)) - - -def load_cpp_sets() -> tuple[set[str], set[str]]: - """Parse ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS from the C++ source.""" - allowed = parse_string_set_from_cc(SUPPORTED_OPS_CC, "ALLOWED_OPERATIONS") - forbidden = parse_string_set_from_cc(SUPPORTED_OPS_CC, "FORBIDDEN_OPERATIONS") - return allowed, forbidden - - -def load_pt_and_collect_ops(pt_path: str) -> set[str] | None: - """Load a saved TorchScript .pt file, inline, and return its op set.""" - try: - module = torch.jit.load(pt_path) - return collect_inlined_ops(module) - except Exception as exc: - print(f" LOAD ERROR: {exc}", file=sys.stderr) - return None - - -def check_ops(ops: set[str], - allowed: set[str], - forbidden: set[str], - verbose: bool) -> bool: - """Check an op set against allowed/forbidden lists. Returns True if all pass.""" - forbidden_found = sorted(ops & forbidden) - unrecognised = sorted(ops - allowed - forbidden) - - if verbose: - print(f" {len(ops)} distinct ops", file=sys.stderr) - - if forbidden_found: - print(f" FORBIDDEN: {forbidden_found}", file=sys.stderr) - if unrecognised: - print(f" UNRECOGNISED: {unrecognised}", file=sys.stderr) - - if not forbidden_found and not unrecognised: - print(f" PASS", file=sys.stderr) - return True - - print(f" FAIL", file=sys.stderr) - return False - - -def validate_model(model_name: str, - allowed: set[str], - forbidden: set[str], - verbose: bool, - quantize: bool = False) -> bool: - """Validate one HuggingFace model. Returns True if all ops pass.""" - label = f"{model_name} (quantized)" if quantize else model_name - print(f" {label}...", file=sys.stderr) - traced = load_and_trace_hf_model(model_name, quantize=quantize) - if traced is None: - print(f" FAILED (could not load/trace)", file=sys.stderr) - return False - ops = collect_inlined_ops(traced) - return check_ops(ops, allowed, forbidden, verbose) - - -def validate_pt_file(name: str, - pt_path: str, - allowed: set[str], - forbidden: set[str], - verbose: bool) -> bool: - """Validate a local TorchScript .pt file. Returns True if all ops pass.""" - print(f" {name} ({pt_path})...", file=sys.stderr) - ops = load_pt_and_collect_ops(pt_path) - if ops is None: - print(f" FAILED (could not load)", file=sys.stderr) - return False - return check_ops(ops, allowed, forbidden, verbose) - - -def main(): - parser = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument( - "--config", type=Path, default=DEFAULT_CONFIG, - help="Path to reference_models.json (default: %(default)s)") - parser.add_argument( - "--pt-dir", type=Path, default=None, - help="Directory of pre-saved .pt TorchScript files to validate") - parser.add_argument( - "--verbose", action="store_true", - help="Print per-model op counts") - args = parser.parse_args() - - print(f"PyTorch version: {torch.__version__}", file=sys.stderr) - - allowed, forbidden = load_cpp_sets() - print(f"Parsed {len(allowed)} allowed ops and {len(forbidden)} " - f"forbidden ops from {SUPPORTED_OPS_CC.name}", file=sys.stderr) - - results: dict[str, bool] = {} - - models = load_model_config(args.config) - - print(f"Validating {len(models)} HuggingFace models from " - f"{args.config.name}...", file=sys.stderr) - - for arch, spec in models.items(): - results[arch] = validate_model( - spec["model_id"], allowed, forbidden, args.verbose, - quantize=spec["quantized"]) - - if args.pt_dir and args.pt_dir.is_dir(): - pt_files = sorted(args.pt_dir.glob("*.pt")) - if pt_files: - print(f"Validating {len(pt_files)} local .pt files from " - f"{args.pt_dir}...", file=sys.stderr) - for pt_path in pt_files: - name = pt_path.stem - results[f"pt:{name}"] = validate_pt_file( - name, str(pt_path), allowed, forbidden, args.verbose) - - print(file=sys.stderr) - print("=" * 60, file=sys.stderr) - all_pass = all(results.values()) - for key, passed in results.items(): - status = "PASS" if passed else "FAIL" - if key.startswith("pt:"): - print(f" {key}: {status}", file=sys.stderr) - else: - spec = models[key] - label = spec["model_id"] - if spec["quantized"]: - label += " (quantized)" - print(f" {key} ({label}): {status}", file=sys.stderr) - - print("=" * 60, file=sys.stderr) - if all_pass: - print("All models PASS - no false positives.", file=sys.stderr) - else: - failed = [a for a, p in results.items() if not p] - print(f"FAILED models: {', '.join(failed)}", file=sys.stderr) - - sys.exit(0 if all_pass else 1) - - -if __name__ == "__main__": - main() diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json deleted file mode 100644 index 0c853cdc5..000000000 --- a/dev-tools/extract_model_ops/validation_models.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "bert": "bert-base-uncased", - "roberta": "roberta-base", - "distilbert": "distilbert-base-uncased", - "electra": "google/electra-small-discriminator", - "mpnet": "microsoft/mpnet-base", - "deberta": "microsoft/deberta-base", - "dpr": "facebook/dpr-ctx_encoder-single-nq-base", - "mobilebert": "google/mobilebert-uncased", - "xlm-roberta": "xlm-roberta-base", - - "elastic-bge-m3": "elastic/bge-m3", - "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", - "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", - "elastic-eis-elser-v2": "elastic/eis-elser-v2", - "elastic-elser-v2": "elastic/elser-v2", - "elastic-hugging-face-elser": "elastic/hugging-face-elser", - "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", - "elastic-splade-v3": "elastic/splade-v3", - "elastic-test-elser-v2": "elastic/test-elser-v2", - - "elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true}, - "elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true}, - "elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}, - - "ner-dslim-bert-base": "dslim/bert-base-NER", - "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english", - - "es-multilingual-e5-small": "intfloat/multilingual-e5-small", - "es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", - "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base" -} diff --git a/dev-tools/generate_malicious_models.py b/dev-tools/generate_malicious_models.py deleted file mode 100644 index 21afe1110..000000000 --- a/dev-tools/generate_malicious_models.py +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Generate malicious TorchScript model fixtures for validator integration tests. - -Each model is designed to exercise a specific attack vector that the -CModelGraphValidator must detect and reject. - -Usage: - python3 generate_malicious_models.py [output_dir] - -The output directory defaults to the same directory as this script. -""" - -import os -import sys -from pathlib import Path - -import torch -from torch import Tensor -from typing import Optional - - -# --- Malicious model definitions --- - - -class FileReaderModel(torch.nn.Module): - """Uses aten::from_file to read arbitrary files from disk.""" - def forward(self, x: Tensor) -> Tensor: - stolen = torch.from_file("/etc/passwd", size=100) - return stolen - - -class MixedFileReaderModel(torch.nn.Module): - """Mixes allowed ops with a forbidden aten::from_file call.""" - def forward(self, x: Tensor) -> Tensor: - y = x + x - z = torch.from_file("/etc/shadow", size=10) - return y + z - - -class HiddenInSubmodule(torch.nn.Module): - """Hides aten::sin (unrecognised) three levels deep in submodules.""" - def __init__(self): - super().__init__() - self.inner = _Inner() - - def forward(self, x: Tensor) -> Tensor: - y = x * x - return self.inner(y) - - -class _Inner(torch.nn.Module): - def __init__(self): - super().__init__() - self.leaf = _Leaf() - - def forward(self, x: Tensor) -> Tensor: - return self.leaf(x) + x - - -class _Leaf(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return torch.sin(x) - - -class ConditionalMalicious(torch.nn.Module): - """Hides an unrecognised op (aten::sin) inside one branch of a conditional.""" - def forward(self, x: Tensor) -> Tensor: - if x.sum() > 0: - return torch.sin(x) - else: - return x + x - - -class ManyUnrecognisedOps(torch.nn.Module): - """Uses several different unrecognised ops to simulate an unexpected arch.""" - def forward(self, x: Tensor) -> Tensor: - a = torch.sin(x) - b = torch.cos(x) - c = torch.tan(x) - d = torch.exp(x) - return a + b + c + d - - -class FileReaderInSubmodule(torch.nn.Module): - """Hides the forbidden aten::from_file inside a submodule.""" - def __init__(self): - super().__init__() - self.reader = _FileReaderChild() - - def forward(self, x: Tensor) -> Tensor: - return x + self.reader(x) - - -class _FileReaderChild(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return torch.from_file("/tmp/secret", size=10) - - -# --- Sandbox2 attack models (PR #2873) --- -# -# These reproduce real-world attack vectors that exploit torch.as_strided -# to leak heap addresses and build ROP chains. The graph validator must -# reject them because as_strided (and several helper ops) are not in the -# transformer-architecture allowlist. - - -class HeapLeakModel(torch.nn.Module): - """Leaks heap addresses via torch.as_strided with a malicious storage offset. - - The attack scans the heap for libtorch pointers by reading out-of-bounds - memory through a strided view. Key unrecognised ops: aten::as_strided, - aten::item, aten::__rshift__. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided( - tensors, [evil_length], [1], 2305843009213693952 - evil_length) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if (leak[i + 3].item() == leak[i + 5].item() - and leak[i + 3].item() >> 40 == 127): - if (leak[i + 2].item() >= 0x100000000 - and leak[i + 2].item() <= 0x100000005): - index.append(i) - return (index, leak) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xdead00000 + i] + [ - 0xdead00000 + j for j in range(2, array_len)] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 - assert 1 == 0, "leak: " + hex(libtorch_cpu_addr) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -class RopExploitModel(torch.nn.Module): - """Builds a ROP chain to mprotect heap memory and execute shellcode. - - Extends the heap-leak technique to overwrite function pointers and - chain gadgets that call mprotect followed by attacker-controlled - shellcode that writes files to disk. Key unrecognised ops: - aten::as_strided, aten::copy_, aten::Delete. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided( - tensors, [evil_length], [1], 2305843009213693952 - evil_length) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if (leak[i + 3].item() == leak[i + 5].item() - and leak[i + 3].item() >> 40 == 127): - if (leak[i + 2].item() >= 0x100000000 - and leak[i + 2].item() <= 0x100000005): - index.append(i) - return (index, leak) - - def get_address_data(self, leak: Tensor, index: int, - tmp: list[Tensor], target_addr: int) -> int: - leak[index] = target_addr - return int(tmp[0].item()) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xdead00000 + i] + [ - 0xdead00000 + j for j in range(2, array_len)] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 - leak[index + 3] = int(heap_addr - 0x8) - for j in range(heap_len): - if tmp[j][2].item() == 0xDEAD00002 and i != j: - mprotect_got_addr = 0xc9785f0 + libtorch_cpu_addr - self.get_address_data( - leak, index + 3, tmp[j:j + 1], - int(mprotect_got_addr)) - del tmp[j] - assert 1 == 0, "exploit" - return torch.zeros(0) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -# --- Generation logic --- - - -MODELS = { - "malicious_file_reader.pt": FileReaderModel, - "malicious_mixed_file_reader.pt": MixedFileReaderModel, - "malicious_hidden_in_submodule.pt": HiddenInSubmodule, - "malicious_conditional.pt": ConditionalMalicious, - "malicious_many_unrecognised.pt": ManyUnrecognisedOps, - "malicious_file_reader_in_submodule.pt": FileReaderInSubmodule, - "malicious_heap_leak.pt": HeapLeakModel, - "malicious_rop_exploit.pt": RopExploitModel, -} - - -def generate(output_dir: Path): - output_dir.mkdir(parents=True, exist_ok=True) - succeeded = [] - failed = [] - - for filename, cls in MODELS.items(): - print(f" {filename}...", end=" ") - try: - model = cls() - model.eval() - scripted = torch.jit.script(model) - path = output_dir / filename - torch.jit.save(scripted, str(path)) - size = path.stat().st_size - print(f"OK ({size} bytes)") - - # Show ops for verification - graph = scripted.forward.graph.copy() - torch._C._jit_pass_inline(graph) - ops = sorted(set(n.kind() for n in graph.nodes())) - print(f" ops: {ops}") - - succeeded.append(filename) - except Exception as exc: - print(f"FAILED: {exc}") - failed.append((filename, str(exc))) - - print(f"\nGenerated {len(succeeded)}/{len(MODELS)} models") - if failed: - print("Failed:") - for name, err in failed: - print(f" {name}: {err}") - return len(failed) == 0 - - -if __name__ == "__main__": - out_dir = (Path(sys.argv[1]) if len(sys.argv) > 1 - else Path(__file__).resolve().parent.parent - / "bin" / "pytorch_inference" / "unittest" / "testfiles" / "malicious_models") - print(f"Generating malicious model fixtures in {out_dir}") - success = generate(out_dir) - sys.exit(0 if success else 1) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 15e49d52a..916d929bc 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -36,7 +36,6 @@ === Enhancements -* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}2936[#2936].) * Better handling of invalid JSON state documents (See {ml-pull}[]#2895].) * Better error handling regarding quantiles state documents (See {ml-pull}[#2894]) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3f5cbc9a0..5e571c729 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -57,14 +57,6 @@ else() set(_build_type_arg "") endif() -# Common arguments for the pytorch_inference allowlist validation script. -set(_validation_args - -DSOURCE_DIR=${CMAKE_SOURCE_DIR} - -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json - -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models - -DVALIDATE_VERBOSE=TRUE -) - add_custom_target(test_all_parallel DEPENDS build_tests COMMAND ${CMAKE_COMMAND} @@ -72,22 +64,5 @@ add_custom_target(test_all_parallel -DBUILD_DIR=${CMAKE_BINARY_DIR} ${_build_type_arg} -P ${CMAKE_SOURCE_DIR}/cmake/run-all-tests-parallel.cmake - COMMAND ${CMAKE_COMMAND} - ${_validation_args} - -DOPTIONAL=TRUE - -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake - WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} -) - -# Standalone target for the pytorch_inference allowlist validation. -# Unlike the invocation inside test_all_parallel (which uses OPTIONAL=TRUE -# to skip gracefully when Python or network access is unavailable), this -# target treats failures as hard errors — use it to explicitly verify the -# allowlist. See dev-tools/extract_model_ops/README.md for details. -add_custom_target(validate_pytorch_inference_models - COMMAND ${CMAKE_COMMAND} - ${_validation_args} - -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake - COMMENT "Validating pytorch_inference allowlist against HuggingFace models and ES integration test models" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} )