diff --git a/README.md b/README.md index 231273cf..862b43d2 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,42 @@ TLS can be enabled in the TGIS containers via the following env vars: These paths can reference mounted secrets containing the certs. +### Local inference + +These steps explain how to run inference outside of Docker, connecting to the TGIS server from a Python script running locally. + +#### Prepare the GRPC client + +Install GRPC: `pip install grpcio grpcio-tools` + +In the repository root, run: + +``` +python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generate.proto +python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generation.proto +``` + +This generates the necessary files in the `pb` directory. This only needs to be done once. + +#### Run the server + +Run the text-generation launcher. For example, if we have a model named "local_model" in $PWD/data: + +``` +volume=$PWD/data +MODEL=/data/local_model +IMAGE_ID=your_image_id +docker run -p 8033:8033 -p 3000:3000 -v $volume:/data $IMAGE_ID text-generation-launcher --model-name $MODEL +``` + +#### Run inference + +In a separate shell (with the environment where you installed GRPC), run: + +``` +python pb/client.py +``` + ### Metrics Prometheus metrics are exposed on the same port as the health probe endpoint (default 3000), at `/metrics`. 
"""Minimal GRPC client example for the TGIS generation service.

Connects to a locally running text-generation server (see the README's
"Local inference" section) and requests a single greedy generation for a
fixed prompt. Requires the stubs generated into this directory by
grpc_tools.protoc (generation_pb2 / generation_pb2_grpc).
"""
import generation_pb2 as pb2
import generation_pb2_grpc as gpb2
import grpc
from google.protobuf import json_format

# Port published by the `docker run -p 8033:8033 ...` command in the README.
port = 8033

# Optional: parameters for inference (greedy decoding, 20-40 new tokens).
params = pb2.Parameters(
    method="GREEDY",
    stopping=pb2.StoppingCriteria(min_new_tokens=20, max_new_tokens=40),
)

prompt = "The weather is"

# Merge the request payload into a BatchedGenerationRequest that already
# carries the decoding parameters.
message = json_format.ParseDict(
    {"requests": [{"text": prompt}]}, pb2.BatchedGenerationRequest(params=params)
)

# Use the channel as a context manager so the connection is closed on exit
# (the original created the channel and never closed it).
with grpc.insecure_channel(f"localhost:{port}") as channel:
    stub = gpb2.GenerationServiceStub(channel)
    response = stub.Generate(message)

print(prompt, response)