Skip to content

Commit 7bc9cbb

Browse files
author
David Eigen
committed
Merge branch 'master' of github.com:Clarifai/clarifai-python
2 parents df603bd + 821a0bb commit 7bc9cbb

File tree

6 files changed

+51
-15
lines changed

6 files changed

+51
-15
lines changed

clarifai/cli/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ def run_locally(model_path, port, mode, keep_env, keep_image):
171171
click.echo(f"Failed to starts model server locally: {e}", err=True)
172172

173173

174+
@model.command()
175+
@click.option(
176+
'--model_path',
177+
type=click.Path(exists=True),
178+
required=True,
179+
help='Path to the model directory.')
180+
def local_dev(model_path):
181+
"""Run the model as a local dev runner to help debug your model connected to the API. You must set several envvars such as CLARIFAI_PAT, CLARIFAI_RUNNER_ID, CLARIFAI_NODEPOOL_ID, CLARIFAI_COMPUTE_CLUSTER_ID. """
182+
from clarifai.runners.server import serve
183+
serve(model_path)
184+
185+
174186
@model.command()
175187
@click.option(
176188
'--config',

clarifai/runners/models/model_run_locally.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,10 @@ def main(model_path,
482482
sys.exit(1)
483483
manager = ModelRunLocally(model_path)
484484
# get whatever stage is in config.yaml to force download now
485+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
485486
_, _, _, when = manager.builder._validate_config_checkpoints()
486-
manager.builder.download_checkpoints(stage=when)
487+
manager.builder.download_checkpoints(
488+
stage=when, checkpoint_path_override=manager.builder.checkpoint_path)
487489
if inside_container:
488490
if not manager.is_docker_installed():
489491
sys.exit(1)

clarifai/runners/server.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,30 +68,43 @@ def main():
6868

6969
parsed_args = parser.parse_args()
7070

71-
builder = ModelBuilder(parsed_args.model_path, download_validation_only=True)
71+
serve(parsed_args.model_path, parsed_args.port, parsed_args.pool_size,
72+
parsed_args.max_queue_size, parsed_args.max_msg_length, parsed_args.enable_tls,
73+
parsed_args.grpc)
74+
75+
76+
def serve(model_path,
77+
port=8000,
78+
pool_size=32,
79+
max_queue_size=10,
80+
max_msg_length=1024 * 1024 * 1024,
81+
enable_tls=False,
82+
grpc=False):
83+
84+
builder = ModelBuilder(model_path, download_validation_only=True)
7285

7386
model = builder.create_model_instance()
7487

7588
# Setup the grpc server for local development.
76-
if parsed_args.grpc:
89+
if grpc:
7790

7891
# initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes.
7992
servicer = ModelServicer(model)
8093

8194
server = GRPCServer(
8295
futures.ThreadPoolExecutor(
83-
max_workers=parsed_args.pool_size,
96+
max_workers=pool_size,
8497
thread_name_prefix="ServeCalls",
8598
),
86-
parsed_args.max_msg_length,
87-
parsed_args.max_queue_size,
99+
max_msg_length,
100+
max_queue_size,
88101
)
89-
server.add_port_to_server('[::]:%s' % parsed_args.port, parsed_args.enable_tls)
102+
server.add_port_to_server('[::]:%s' % port, enable_tls)
90103

91104
service_pb2_grpc.add_V2Servicer_to_server(servicer, server)
92105
server.start()
93-
logger.info("Started server on port %s", parsed_args.port)
94-
logger.info(f"Access the model at http://localhost:{parsed_args.port}")
106+
logger.info("Started server on port %s", port)
107+
logger.info(f"Access the model at http://localhost:{port}")
95108
server.wait_for_termination()
96109
else: # start the runner with the proper env variables and as a runner protocol.
97110

tests/runners/test_download_checkpoints.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66

77
from clarifai.runners.models.model_builder import ModelBuilder
8-
from clarifai.runners.utils.const import DEFAULT_RUNTIME_DOWNLOAD_PATH
98
from clarifai.runners.utils.loader import HuggingFaceLoader
109

1110
MODEL_ID = "timm/mobilenetv3_small_100.lamb_in1k"
@@ -62,16 +61,20 @@ def test_download_checkpoints(dummy_runner_models_dir):
6261
model_builder = ModelBuilder(model_folder_path, download_validation_only=True)
6362
# defaults to runtime stage which matches config.yaml not having a when field.
6463
# get whatever stage is in config.yaml to force download now
64+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
6565
_, _, _, when = model_builder._validate_config_checkpoints()
66-
checkpoint_dir = model_builder.download_checkpoints(stage=when)
67-
assert checkpoint_dir == DEFAULT_RUNTIME_DOWNLOAD_PATH
66+
checkpoint_dir = model_builder.download_checkpoints(
67+
stage=when, checkpoint_path_override=model_builder.checkpoint_path)
68+
assert checkpoint_dir == model_builder.checkpoint_path
6869

6970
# This doesn't have when in it's config.yaml so build.
7071
model_folder_path = os.path.join(os.path.dirname(__file__), "hf_mbart_model")
7172
model_builder = ModelBuilder(model_folder_path, download_validation_only=True)
7273
# defaults to runtime stage which matches config.yaml not having a when field.
7374
# get whatever stage is in config.yaml to force download now
75+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
7476
_, _, _, when = model_builder._validate_config_checkpoints()
75-
checkpoint_dir = model_builder.download_checkpoints(stage=when)
77+
checkpoint_dir = model_builder.download_checkpoints(
78+
stage=when, checkpoint_path_override=model_builder.checkpoint_path)
7679
assert checkpoint_dir == os.path.join(
7780
os.path.dirname(__file__), "hf_mbart_model", "1", "checkpoints")

tests/runners/test_model_run_locally-container.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,12 @@ def test_hf_docker_build_and_test_container(hf_model_run_locally):
8181
This test will be skipped if Docker is not installed.
8282
"""
8383

84+
# get whatever stage is in config.yaml to force download now
85+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
8486
# Download the checkpoints for the model
85-
hf_model_run_locally.builder.download_checkpoints()
87+
_, _, _, when = hf_model_run_locally.builder._validate_config_checkpoints()
88+
hf_model_run_locally.builder.download_checkpoints(
89+
stage=when, checkpoint_path_override=hf_model_run_locally.builder.checkpoint_path)
8690

8791
# Test if Docker is installed
8892
assert hf_model_run_locally.is_docker_installed(), "Docker not installed, skipping."

tests/runners/test_model_run_locally.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ def test_hf_test_model_success(hf_model_run_locally):
171171
This calls the script's test_model method, which runs a subprocess.
172172
"""
173173
# get whatever stage is in config.yaml to force download now
174+
# also always write to where upload/build wants to, not the /tmp folder that runtime stage uses
174175
_, _, _, when = hf_model_run_locally.builder._validate_config_checkpoints()
175-
hf_model_run_locally.builder.download_checkpoints(stage=when)
176+
hf_model_run_locally.builder.download_checkpoints(
177+
stage=when, checkpoint_path_override=hf_model_run_locally.builder.checkpoint_path)
176178
hf_model_run_locally.create_temp_venv()
177179
hf_model_run_locally.install_requirements()
178180

0 commit comments

Comments
 (0)