Skip to content

Commit ce6115b

Browse files
fix: add tp_size for sglang_wrapper
1 parent bc3cfc3 commit ce6115b

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

graphgen/models/llm/local/sglang_wrapper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ def __init__(
3333
self.temperature = temperature
3434
self.top_p = top_p
3535
self.topk = topk
36+
self.tp_size = int(tp_size)
3637

3738
# Initialise the offline engine
38-
self.engine = sgl.Engine(model_path=self.model_path, tp_size=int(tp_size))
39+
self.engine = sgl.Engine(model_path=self.model_path, tp_size=tp_size)
3940

4041
# Keep helpers for streaming
4142
self.async_stream_and_merge = async_stream_and_merge
@@ -146,4 +147,6 @@ def shutdown(self) -> None:
146147
def restart(self) -> None:
147148
"""Restart the SGLang engine."""
148149
self.shutdown()
149-
self.engine = self.engine.__class__(model_path=self.model_path)
150+
self.engine = self.engine.__class__(
151+
model_path=self.model_path, tp_size=self.tp_size
152+
)

0 commit comments

Comments
 (0)