 import traceback
 import logging
 from typing import List, Optional, Tuple, Dict, Any
+import threading
 from polygraphy import cuda

 from ..utilities import Engine
@@ -22,6 +23,8 @@ def __init__(self, engine_path: str, stream: 'cuda.Stream', use_cuda_graph: bool
         self.stream = stream
         self.use_cuda_graph = use_cuda_graph
         self.model_type = model_type.lower()
+        # Serialize infer calls per engine/context for safety
+        self._infer_lock = threading.RLock()

         self.engine.load()
         self.engine.activate()
@@ -108,15 +111,16 @@ def __call__(self,
         output_shapes = self._resolve_output_shapes(batch_size, latent_height, latent_width)
         shape_dict.update(output_shapes)

-        self.engine.allocate_buffers(shape_dict=shape_dict, device=sample.device)
-
-        outputs = self.engine.infer(
-            input_dict,
-            self.stream,
-            use_cuda_graph=self.use_cuda_graph,
-        )
-
-        self.stream.synchronize()
+        with self._infer_lock:
+            self.engine.allocate_buffers(shape_dict=shape_dict, device=sample.device)
+            outputs = self.engine.infer(
+                input_dict,
+                self.stream,
+                use_cuda_graph=self.use_cuda_graph,
+            )
+            # Synchronize to ensure outputs are ready before consumption by other streams.
+            # This preserves correctness when UNet runs on a different CUDA stream.
+            self.stream.synchronize()

         down_blocks, mid_block = self._extract_controlnet_outputs(outputs)
         return down_blocks, mid_block