Skip to content

Commit 778918e

Browse files
authored
[Enhancement] Pass model_init_kwargs to sparse model auto tracing to load remote model class (#555)
* add trust remote code Signed-off-by: zhichao-aws <[email protected]> * pass model kwargs Signed-off-by: zhichao-aws <[email protected]> * sanitize, lint Signed-off-by: zhichao-aws <[email protected]> * add ut Signed-off-by: zhichao-aws <[email protected]> * changelog Signed-off-by: zhichao-aws <[email protected]> * fix json string in GH action Signed-off-by: zhichao-aws <[email protected]> --------- Signed-off-by: zhichao-aws <[email protected]>
1 parent 899cb82 commit 778918e

File tree

7 files changed

+159
-12
lines changed

7 files changed

+159
-12
lines changed

.ci/run-repository.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,21 @@ elif [[ "$TASK_TYPE" == "SentenceTransformerTrace" || "$TASK_TYPE" == "SparseTra
7777
echo -e "\033[34;1mINFO:\033[0m ACTIVATION: ${ACTIVATION:-N/A}\033[0m"
7878
echo -e "\033[34;1mINFO:\033[0m MODEL_DESCRIPTION: ${MODEL_DESCRIPTION:-N/A}\033[0m"
7979
echo -e "\033[34;1mINFO:\033[0m MODEL_NAME: ${MODEL_NAME:-N/A}\033[0m"
80+
# Quote the default: in an unquoted ${VAR:-{}} the first '}' closes the expansion,
# so a set value would be logged with a stray trailing '}'.
echo -e "\033[34;1mINFO:\033[0m MODEL_INIT_KWARGS: ${MODEL_INIT_KWARGS:-"{}"}\033[0m"
8081

8182
if [[ "$TASK_TYPE" == "SentenceTransformerTrace" ]]; then
8283
NOX_TRACE_TYPE="trace"
83-
EXTRA_ARGS="-ed ${EMBEDDING_DIMENSION} -pm ${POOLING_MODE}"
84+
EXTRA_ARGS=( -ed "${EMBEDDING_DIMENSION}" -pm "${POOLING_MODE}" )
8485
elif [[ "$TASK_TYPE" == "SparseTrace" ]]; then
8586
NOX_TRACE_TYPE="sparsetrace"
86-
EXTRA_ARGS="-spr ${SPARSE_PRUNE_RATIO} -act ${ACTIVATION}"
87+
EXTRA_ARGS=( -spr "${SPARSE_PRUNE_RATIO}" -act "${ACTIVATION}" -mik "${MODEL_INIT_KWARGS}" )
8788
elif [[ "$TASK_TYPE" == "SparseTokenizerTrace" ]]; then
8889
NOX_TRACE_TYPE="sparsetrace"
89-
# use extra args to trigger the tokenizer tracing logics
90-
EXTRA_ARGS="-t"
90+
# Use extra args to trigger the tokenizer tracing logic (-mik is not passed for tokenizer tracing)
91+
EXTRA_ARGS=( -t )
9192
elif [[ "$TASK_TYPE" == "SemanticHighlighterTrace" ]]; then
9293
NOX_TRACE_TYPE="semantic_highlighter_trace"
93-
EXTRA_ARGS=""
94+
EXTRA_ARGS=()
9495
else
9596
echo "Unknown TASK_TYPE: $TASK_TYPE"
9697
exit 1
@@ -105,7 +106,7 @@ elif [[ "$TASK_TYPE" == "SentenceTransformerTrace" || "$TASK_TYPE" == "SparseTra
105106
-up "${UPLOAD_PREFIX}"
106107
-mn "${MODEL_NAME}"
107108
-md "${MODEL_DESCRIPTION:+"$MODEL_DESCRIPTION"}"
108-
${EXTRA_ARGS}
109+
"${EXTRA_ARGS[@]}"
109110
)
110111

111112
echo "nox -s ${nox_command[@]}"

.github/workflows/model_uploader.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ on:
6767
"sparse_prune_ratio": (Optional) Float. Specifies the model-side prune ratio based on max values. Sparse model only.
6868
"activation": (Optional) String. Specifies the activation function for the sparse model. Sparse model only.
6969
"model_name": (Optional) String. Specifies the model name for uploading. Example: transforms "sentence-transformers/model" to "sentence-transformers/{model_name}",
70+
"model_init_kwargs": (Optional) Object. JSON object to pass to from_pretrained via **kwargs.
7071
}
7172
7273
Example:
@@ -102,23 +103,27 @@ jobs:
102103
sparse_prune_ratio=0
103104
activation=""
104105
model_name="${model_id##*/}"
106+
model_init_kwargs="{}"
105107
106108
if [ "$custom_params" != "{}" ] && [ -n "$custom_params" ]; then
107109
tmp_up=$(echo "$custom_params" | jq -r '.upload_prefix | select(.!=null)')
108110
tmp_spr=$(echo "$custom_params" | jq -r '.sparse_prune_ratio | select(.!=null)')
109111
tmp_act=$(echo "$custom_params" | jq -r '.activation | select(.!=null)')
110112
tmp_mn=$(echo "$custom_params" | jq -r '.model_name | select(.!=null)')
113+
tmp_mik=$(echo "$custom_params" | jq -c '.model_init_kwargs | select(.!=null)')
111114
112115
[ -n "$tmp_up" ] && upload_prefix="$tmp_up"
113116
[ -n "$tmp_spr" ] && sparse_prune_ratio="$tmp_spr"
114117
[ -n "$tmp_act" ] && activation="$tmp_act"
115118
[ -n "$tmp_mn" ] && model_name="$tmp_mn"
119+
[ -n "$tmp_mik" ] && model_init_kwargs="$tmp_mik"
116120
fi
117121
118122
echo "upload_prefix=$upload_prefix" >> $GITHUB_OUTPUT
119123
echo "sparse_prune_ratio=$sparse_prune_ratio" >> $GITHUB_OUTPUT
120124
echo "activation=$activation" >> $GITHUB_OUTPUT
121125
echo "model_name=$model_name" >> $GITHUB_OUTPUT
126+
echo "model_init_kwargs=$model_init_kwargs" >> $GITHUB_OUTPUT
122127
- name: Initiate folders
123128
# This scripts init the folders path variables.
124129
# 1. Retrieves the input model_id.
@@ -167,6 +172,7 @@ jobs:
167172
- Model Prefix Folder: ${{ steps.init_folders.outputs.model_prefix_folder }}
168173
- Sparse Prune Ratio: ${{ steps.parse_custom_params.outputs.sparse_prune_ratio || 'N/A' }}
169174
- Activation: ${{ steps.parse_custom_params.outputs.activation || 'N/A' }}
175+
- Model Init Kwargs: '${{ toJSON(fromJSON(steps.parse_custom_params.outputs.model_init_kwargs || '{}')) }}'
170176
171177
======== Workflow Output Information =========
172178
- Embedding Verification: Passed"
@@ -190,6 +196,7 @@ jobs:
190196
sparse_prune_ratio: ${{ steps.parse_custom_params.outputs.sparse_prune_ratio }}
191197
activation: ${{ steps.parse_custom_params.outputs.activation }}
192198
model_name: ${{ steps.parse_custom_params.outputs.model_name }}
199+
model_init_kwargs: ${{ steps.parse_custom_params.outputs.model_init_kwargs }}
193200

194201
# Step 3: Check if the model already exists in the model hub
195202
checking-out-model-hub:
@@ -267,6 +274,9 @@ jobs:
267274
echo "MODEL_NAME=${{ needs.init-workflow-var.outputs.model_name }}" >> $GITHUB_ENV
268275
echo "SPARSE_PRUNE_RATIO=${{ needs.init-workflow-var.outputs.sparse_prune_ratio }}" >> $GITHUB_ENV
269276
echo "ACTIVATION=${{ needs.init-workflow-var.outputs.activation }}" >> $GITHUB_ENV
277+
echo "MODEL_INIT_KWARGS<<EOF" >> $GITHUB_ENV
278+
echo '${{ toJSON(fromJSON(needs.init-workflow-var.outputs.model_init_kwargs)) }}' >> $GITHUB_ENV
279+
echo "EOF" >> $GITHUB_ENV
270280
- name: Autotracing ${{ matrix.cluster }} secured=${{ matrix.secured }} version=${{matrix.entry.opensearch_version}}
271281
run: "./.ci/run-tests ${{ matrix.cluster }} ${{ matrix.secured }} ${{ matrix.entry.opensearch_version }} ${{github.event.inputs.model_type}}Trace"
272282
shell: bash

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
66
### Added
77
- Add space type mapping for sentence transformer models by @nathaliellenaa in ([#512](https://github.com/opensearch-project/opensearch-py-ml/pull/512))
88
- Add example script for deploying semantic highlighter model on aws sagemaker. ([#513](https://github.com/opensearch-project/opensearch-py-ml/pull/513))
9+
- Add model_init_kwargs to sparse model uploading pipeline. ([#555](https://github.com/opensearch-project/opensearch-py-ml/pull/555))
910

1011
### Changed
1112
- Update model upload history - opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill (v.1.0.0)(TORCH_SCRIPT) by @dhrubo-os ([#415](https://github.com/opensearch-project/opensearch-py-ml/pull/415))

opensearch_py_ml/ml_models/sparse_encoding_model.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
# GitHub history for details.
77
import json
88
import os
9+
import re
10+
from typing import Optional
911
from zipfile import ZipFile
1012

1113
import torch
@@ -36,6 +38,30 @@ def _generate_default_model_description() -> str:
3638
return description
3739

3840

41+
def _sanitize_module_name(name: str) -> str:
    """
    Return ``name`` rewritten as a dotted path of identifier-safe segments.

    Every character outside ``[0-9A-Za-z_.]`` is replaced with ``_``. The
    result is split on ``.``; empty segments (from leading, trailing or
    repeated dots) are dropped, and any segment starting with a digit is
    prefixed with ``n_``. The remaining segments are joined back with ``.``.
    """
    cleaned = re.sub(r"[^0-9A-Za-z_\.]", "_", name)
    segments = [
        f"n_{segment}" if segment[0].isdigit() else segment
        for segment in cleaned.split(".")
        if segment
    ]
    return ".".join(segments)
51+
52+
53+
def sanitize_model_modules(model: torch.nn.Module) -> None:
    """
    Sanitize the ``__module__`` attribute of every class used inside ``model``.

    Walks ``model.modules()`` and handles each distinct class once. When the
    sanitized module name is non-empty and differs from the class's current
    ``__module__``, the attribute is overwritten on the class itself, so the
    change is in place and applies to the class, not just one instance.

    NOTE(review): presumably needed so classes loaded from remote code can be
    traced with ``torch.jit`` -- confirm against the caller.
    """
    visited = set()
    for submodule in model.modules():
        klass = submodule.__class__
        if klass not in visited:
            sanitized = _sanitize_module_name(getattr(klass, "__module__", ""))
            # Skip the write when there is nothing usable or nothing to change.
            if sanitized and sanitized != klass.__module__:
                klass.__module__ = sanitized
            visited.add(klass)
63+
64+
3965
class SparseEncodingModel(SparseModel):
4066
"""
4167
Class for exporting and configuring the NeuralSparseV2Model model.
@@ -50,12 +76,15 @@ def __init__(
5076
overwrite: bool = False,
5177
sparse_prune_ratio: float = 0,
5278
activation: str = None,
79+
model_init_kwargs: Optional[dict] = None,
5380
) -> None:
5481

5582
super().__init__(model_id, folder_path, overwrite)
56-
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
83+
if model_init_kwargs is None:
84+
model_init_kwargs = {}
85+
self.tokenizer = AutoTokenizer.from_pretrained(model_id, **model_init_kwargs)
5786
self.backbone_model = AutoModelForMaskedLM.from_pretrained(
58-
model_id, _attn_implementation="eager"
87+
model_id, _attn_implementation="eager", **model_init_kwargs
5988
)
6089
default_folder_path = os.path.join(
6190
os.getcwd(), "opensearch_neural_sparse_model_files"
@@ -167,6 +196,7 @@ def save_as_pt(
167196
return_tensors="pt",
168197
).to(device)
169198

199+
sanitize_model_modules(cpu_model)
170200
compiled_model = torch.jit.trace(cpu_model, dict(features), strict=False)
171201
torch.jit.save(compiled_model, model_path)
172202
print("model file is saved to ", model_path)

tests/ml_models/test_sparseencondingmodel_pytest.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
from zipfile import ZipFile
1414

1515
import pytest
16+
from torch import nn
1617

1718
from opensearch_py_ml.ml_models import SparseEncodingModel
19+
from opensearch_py_ml.ml_models.sparse_encoding_model import sanitize_model_modules
20+
from utils.model_uploader.autotracing_utils import init_sparse_model
1821

1922
TEST_FOLDER = os.path.join(
2023
os.path.dirname(os.path.abspath("__file__")), "tests", "test_model_files"
@@ -374,5 +377,69 @@ def test_process_sparse_encoding():
374377
check_value(1.0706572532653809, encoding_result[1]["hello"], 0.001)
375378

376379

380+
def test_sanitize_module_name_and_trace():
    """sanitize_model_modules rewrites an unsafe ``__module__`` on a submodule class."""

    class WeirdSub(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input_ids=None, attention_mask=None):
            pass

    # simulate weird remote module path
    WeirdSub.__module__ = "remote.repo@bad:name/1"

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.m = WeirdSub()

        def forward(self, features: dict):
            pass

    model = Toy().eval()
    sanitize_model_modules(model)

    sanitized = model.m.__class__.__module__

    # Pin the exact rewrite: '@', ':' and '/' each become '_', dots are kept.
    assert sanitized == "remote.repo_bad_name_1"

    # After sanitize, module name should only contain [0-9A-Za-z_.]
    # (isascii guards against str.isalnum accepting non-ASCII alphanumerics).
    assert all(c.isascii() and (c.isalnum() or c in {"_", "."}) for c in sanitized)
404+
405+
406+
def test_init_sparse_model_kwargs_passthrough():
    """init_sparse_model forwards every argument, including model_init_kwargs, to the model class."""
    received = {}

    class FakeModel:
        # Records the keyword arguments it is constructed with; does no I/O.
        def __init__(
            self,
            model_id,
            folder_path,
            overwrite,
            sparse_prune_ratio,
            activation,
            model_init_kwargs,
        ):
            received.update(
                model_id=model_id,
                folder_path=folder_path,
                overwrite=overwrite,
                sparse_prune_ratio=sparse_prune_ratio,
                activation=activation,
                model_init_kwargs=model_init_kwargs,
            )

    init_kwargs = {"trust_remote_code": True, "revision": "dev"}

    model = init_sparse_model(
        FakeModel,
        model_id="foo/bar",
        folder_path="/tmp/xyz",
        sparse_prune_ratio=0.2,
        activation="l0",
        model_init_kwargs=init_kwargs,
    )

    assert isinstance(model, FakeModel)
    assert received["model_id"] == "foo/bar"
    assert received["folder_path"] == "/tmp/xyz"
    assert received["overwrite"] is True
    assert received["sparse_prune_ratio"] == 0.2
    assert received["activation"] == "l0"
    assert received["model_init_kwargs"]["trust_remote_code"] is True
    # Check the whole mapping so a dropped or altered key (e.g. "revision") is caught.
    assert received["model_init_kwargs"] == {
        "trust_remote_code": True,
        "revision": "dev",
    }
442+
443+
377444
clean_test_folder(TEST_FOLDER)
378445
clean_test_folder(TESTDATA_UNZIP_FOLDER)

utils/model_uploader/autotracing_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import shutil
1010
import warnings
11-
from typing import Type, TypeVar
11+
from typing import Any, Dict, Optional, Type, TypeVar
1212

1313
from huggingface_hub import HfApi
1414

@@ -230,15 +230,23 @@ def __init__(self, stage: str, model_format: str, original_exception: Exception)
230230

231231

232232
def init_sparse_model(
233-
model_class: Type[T], model_id, folder_path, sparse_prune_ratio=0, activation=None
233+
model_class: Type[T],
234+
model_id,
235+
folder_path,
236+
sparse_prune_ratio=0,
237+
activation=None,
238+
model_init_kwargs: Optional[Dict[str, Any]] = None,
234239
) -> T:
235240
try:
241+
if model_init_kwargs is None:
242+
model_init_kwargs = {}
236243
pre_trained_model = model_class(
237244
model_id=model_id,
238245
folder_path=folder_path,
239246
overwrite=True,
240247
sparse_prune_ratio=sparse_prune_ratio,
241248
activation=activation,
249+
model_init_kwargs=model_init_kwargs,
242250
)
243251
except Exception as e:
244252
raise ModelTraceError("initiating a sparse encoding model class object", e)

0 commit comments

Comments
 (0)