Skip to content

Commit bc3cfc3

Browse files
fix: add tp_size for sglang_wrapper
1 parent b84ae9e commit bc3cfc3

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

graphgen/models/llm/local/sglang_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
temperature: float = 0.0,
1717
top_p: float = 1.0,
1818
topk: int = 5,
19+
tp_size: int = 1,
1920
**kwargs: Any,
2021
):
2122
super().__init__(temperature=temperature, top_p=top_p, **kwargs)
@@ -34,7 +35,7 @@ def __init__(
3435
self.topk = topk
3536

3637
# Initialise the offline engine
37-
self.engine = sgl.Engine(model_path=self.model_path)
38+
self.engine = sgl.Engine(model_path=self.model_path, tp_size=int(tp_size))
3839

3940
# Keep helpers for streaming
4041
self.async_stream_and_merge = async_stream_and_merge

0 commit comments

Comments
 (0)