
Commit 51d5e9b

Authored by mgazz, christian-pinto, and DarkLight1337
[Core][Model] Terratorch backend integration (vllm-project#23513)
Signed-off-by: Michele Gazzetti <[email protected]>
Signed-off-by: Christian Pinto <[email protected]>
Co-authored-by: Christian Pinto <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent e7fc700 · commit 51d5e9b

23 files changed: +305 −208 lines


examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 5 additions & 1 deletion

@@ -45,7 +45,11 @@
 class PrithviMAE:
     def __init__(self, model):
         self.model = LLM(
-            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
+            model=model,
+            skip_tokenizer_init=True,
+            dtype="float16",
+            enforce_eager=True,
+            model_impl="terratorch",
         )

     def run(self, input_data, location_coords):
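For reference, the change above amounts to selecting the new Terratorch model implementation at LLM construction time. A minimal, self-contained sketch (the checkpoint name is borrowed from tests/entrypoints/openai/test_skip_tokenizer.py later in this commit; the other arguments mirror the diff above):

    from vllm import LLM

    # Minimal sketch: route model loading through the Terratorch backend.
    # Checkpoint name taken from the test file in this commit; the other
    # arguments mirror the example diff above.
    llm = LLM(
        model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        skip_tokenizer_init=True,   # geospatial model, no text tokenizer
        dtype="float16",
        enforce_eager=True,
        model_impl="terratorch",    # new: select the Terratorch implementation
    )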

examples/offline_inference/prithvi_geospatial_mae_io_processor.py

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ def main():
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
         io_processor_plugin="prithvi_to_tiff_india",
+        model_impl="terratorch",
     )

     pooling_params = PoolingParams(task="encode", softmax=False)
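Putting the fragments of this hunk together, the offline IO-processor flow becomes the following sketch (not verbatim from the file; the subsequent encode() step and its prompt format are assumptions based on vLLM's pooling API, not part of this diff):

    from vllm import LLM, PoolingParams

    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",  # assumed; name from the serving example below
        max_num_seqs=32,  # the maximum depends on available GPU memory
        io_processor_plugin="prithvi_to_tiff_india",
        model_impl="terratorch",
    )
    pooling_params = PoolingParams(task="encode", softmax=False)
    # The IO processor plugin converts the raw geospatial input into model
    # inputs and the pooled output back into a TIFF; the exact prompt dict
    # passed to llm.encode(...) is plugin-defined.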

examples/online_serving/prithvi_geospatial_mae.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 # https://github.com/christian-pinto/prithvi_io_processor_plugin
 # - start vllm in serving mode with the below args
 # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
+# --model-impl terratorch
 # --task embed --trust-remote-code
 # --skip-tokenizer-init --enforce-eager
 # --io-processor-plugin prithvi_to_tiff_india
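Assembled into a single invocation, the serving command described by these comments would look like this (assuming the standard vllm serve entrypoint; the flags are exactly those listed above):

    vllm serve christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM \
        --model-impl terratorch \
        --task embed --trust-remote-code \
        --skip-tokenizer-init --enforce-eager \
        --io-processor-plugin prithvi_to_tiff_india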

requirements/test.in

Lines changed: 1 addition & 1 deletion

@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
 decord==0.6.0
+terratorch==1.1rc3 # required for PrithviMAE test

requirements/test.txt

Lines changed: 1 addition & 1 deletion

@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
     # via lightning
 tensorizer==2.10.1
     # via -r requirements/test.in
-terratorch==1.1rc2
+terratorch==1.1rc3
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn

tests/distributed/test_pipeline_parallel.py

Lines changed: 6 additions & 0 deletions

@@ -298,6 +298,8 @@ def _compare_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
     hf_config = get_config(model_id, trust_remote_code)
+    skip_tokenizer_init = model_info.skip_tokenizer_init
+    max_num_seqs = model_info.max_num_seqs

     dtype = "float16"
     if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@@ -351,6 +353,10 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")
+    if max_num_seqs:
+        common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

     specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
     testing_ray_compiled_graph = False

tests/distributed/test_sequence_parallel.py

Lines changed: 3 additions & 0 deletions

@@ -178,6 +178,7 @@ def _compare_sp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
+    skip_tokenizer_init = model_info.skip_tokenizer_init

     if load_format == "dummy":
         # Avoid OOM
@@ -227,6 +228,8 @@ def _compare_sp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")

     compilation_config = {
         'level': 3,

tests/entrypoints/openai/test_chat_template.py

Lines changed: 3 additions & 1 deletion

@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         trust_remote_code=model_info.trust_remote_code,
         revision=model_info.revision,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Initialize the tokenizer
     tokenizer = get_tokenizer(

tests/entrypoints/openai/test_skip_tokenizer.py

Lines changed: 4 additions & 2 deletions

@@ -11,7 +11,7 @@

 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 DTYPE = "float16"


@@ -35,7 +35,9 @@ def server():
         "--trust-remote-code",
         "--skip-tokenizer-init",
         "--max-num-seqs",
-        "32"
+        "32",
+        "--model-impl",
+        "terratorch"
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/test_chat_utils.py

Lines changed: 9 additions & 3 deletions

@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model,
@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model_config.tokenizer,
