Skip to content

Commit ec9f13d

Browse files
committed
misc review comments
Signed-off-by: Sage Moore <[email protected]>
1 parent 49cdc3d commit ec9f13d

File tree

3 files changed

+2
-13
lines changed

3 files changed

+2
-13
lines changed

vllm/compilation/ubatch_wrapper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig,
         self.runnable = runnable
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
-        self.comm_stream = torch.cuda.Stream()
-        self.device = device
+        self.comm_stream = torch.cuda.Stream(device=device)
         self.ready_barrier = threading.Barrier(3)

         self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
@@ -204,8 +203,7 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids,
             comm_stream=self.comm_stream,
             compute_stream=compute_stream,
             forward_contexts=forward_contexts,
-            ready_barrier=self.ready_barrier,
-            device=self.device)
+            ready_barrier=self.ready_barrier)

         ubatch_metadata: list[UbatchMetadata] = []
         for i, ubatch_slice in enumerate(ubatch_slices):

vllm/v1/worker/gpu_model_runner.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,12 +2494,6 @@ def _dummy_run(
             # We only support decode-only cudagraphs
             assert num_reqs == num_tokens
             assert num_tokens % 2 == 0
-            # num_tokens_per_ubatch = num_tokens // 2
-            # dp_size = self.vllm_config.parallel_config.data_parallel_size
-            # num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
-            #                                     dp_size,
-            #                                     device="cpu",
-            #                                     dtype=torch.int32)
             ubatch_slices = [
                 UbatchSlice(slice(0, num_reqs // 2), slice(0,
                                                            num_tokens // 2)),

vllm/v1/worker/ubatching.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def make_ubatch_contexts(
     comm_stream: torch.cuda.Stream,
     forward_contexts: list[ForwardContext],
     ready_barrier: threading.Barrier,
-    device: Optional[torch.device] = None,
     schedule: str = "default",
 ) -> list[UBatchContext]:
     assert num_micro_batches == 2, "only been tested with 2 micro-batches"
@@ -186,8 +185,6 @@ def make_ubatch_contexts(
     gpu_compute_done_events = [
         torch.cuda.Event() for _ in range(num_micro_batches)
     ]
-    device = device or torch.cuda.current_device()
-    # comm_stream = torch.cuda.Stream(device)

     assert len(forward_contexts) == 2

0 commit comments

Comments
 (0)