Add example for running inference locally #18

README.md: 36 additions & 0 deletions

@@ -113,6 +113,42 @@ TLS can be enabled in the TGIS containers via the following env vars:

These paths can reference mounted secrets containing the certs.

### Local inference

The following steps show how to run inference against the TGIS server from a Python script on your local machine, outside of the container.

#### Prepare the gRPC client

Install the gRPC Python tooling: `pip install grpcio grpcio-tools`

From the repository root, run:

```
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generate.proto
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generation.proto
```

This generates the required Python stubs in the `pb` directory; it only needs to be done once.
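
To verify the result, a quick import check like the following should run without errors (a sketch; it assumes you run it from the repository root, where the stubs landed in `pb/`):

```
import sys

sys.path.insert(0, "pb")  # generated modules live in pb/
import generation_pb2

print(generation_pb2.DESCRIPTOR.package)  # prints the proto package name
```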

#### Run the server

Run `text-generation-launcher`. For example, if you have a model named `local_model` under `$PWD/data`:

```
volume=$PWD/data
MODEL=/data/local_model
IMAGE_ID=your_image_id
docker run -p 8033:8033 -p 3000:3000 -v $volume:/data $IMAGE_ID text-generation-launcher --model-name $MODEL
```
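
The server can take a while to load the model. A minimal readiness probe (a sketch, assuming the default gRPC port 8033):

```
import grpc

# Block until the TGIS gRPC endpoint accepts connections (or time out).
channel = grpc.insecure_channel("localhost:8033")
grpc.channel_ready_future(channel).result(timeout=60)
print("TGIS server is ready")
```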

#### Run inference

In a separate shell (with the environment where you installed the gRPC packages), run:

```
python pb/client.py
```
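
The client above sends a single batched request. The service also defines a streaming RPC; a streaming client might look like the sketch below (it assumes the `GenerateStream` RPC and `SingleGenerationRequest` message from `proto/generation.proto`; field names may differ in your version of the proto):

```
import sys

sys.path.insert(0, "pb")  # generated stubs
import grpc
import generation_pb2 as pb2
import generation_pb2_grpc as gpb2

channel = grpc.insecure_channel("localhost:8033")
stub = gpb2.GenerationServiceStub(channel)

# Receive generated text incrementally instead of waiting for the full response.
request = pb2.SingleGenerationRequest(
    request=pb2.GenerationRequest(text="The weather is"),
    params=pb2.Parameters(stopping=pb2.StoppingCriteria(max_new_tokens=40)),
)
for chunk in stub.GenerateStream(request):
    print(chunk.text, end="", flush=True)
```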

### Metrics

Prometheus metrics are exposed on the same port as the health probe endpoint (default 3000), at `/metrics`.
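
For a quick check from Python (a sketch, assuming the default port 3000):

```
import urllib.request

# Fetch the Prometheus metrics exposed alongside the health probe endpoint.
with urllib.request.urlopen("http://localhost:3000/metrics") as resp:
    print(resp.read().decode()[:500])  # first few metric lines
```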
pb/client.py: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
import generation_pb2 as pb2
import generation_pb2_grpc as gpb2
import grpc
from google.protobuf import json_format

# Connect to the TGIS server on the gRPC port exposed by the container.
port = 8033
channel = grpc.insecure_channel(f"localhost:{port}")
stub = gpb2.GenerationServiceStub(channel)

# Optional: inference parameters (greedy decoding with stopping criteria).
params = pb2.Parameters(
    method="GREEDY", stopping=pb2.StoppingCriteria(min_new_tokens=20, max_new_tokens=40)
)

prompt = "The weather is"

# Build a batched request with a single prompt; ParseDict merges the
# request list into the pre-populated message.
message = json_format.ParseDict(
    {"requests": [{"text": prompt}]}, pb2.BatchedGenerationRequest(params=params)
)
response = stub.Generate(message)
print(prompt, response)
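
# Variant (sketch, commented out): sampling instead of greedy decoding. This
# assumes SamplingParameters with temperature/top_k fields as defined in
# proto/generation.proto; adjust the field names to your proto version.
# params = pb2.Parameters(
#     method="SAMPLE",
#     sampling=pb2.SamplingParameters(temperature=0.7, top_k=50),
#     stopping=pb2.StoppingCriteria(min_new_tokens=20, max_new_tokens=40),
# )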