@@ -787,10 +787,10 @@ def compute_attn_mask_seqlen(
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
-        # Convert grid_thw to tensor (always expecting list format now)
-        grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
+        if isinstance(grid_thw, list):
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)

         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
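As a standalone illustration of the new input handling, a minimal sketch with made-up values (rows of grid_thw are assumed to be (t, h, w) triples, matching the indexing in the next hunk):

```python
import torch

# grid_thw may now arrive either as a tensor or as a nested [t, h, w] list;
# lists are converted once, on CPU, into an int32 tensor of shape (N, 3).
grid_thw = [[1, 4, 4], [2, 2, 2]]  # hypothetical: a 1x4x4 image and a 2x2x2 clip
if isinstance(grid_thw, list):
    grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
assert grid_thw.shape == (2, 3)
```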
@@ -805,7 +805,8 @@ def forward(
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
         max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
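A self-contained sketch of the cu_seqlens construction above, with hypothetical grid values; each frame contributes h * w patches, and torch.cat with new_zeros(1) prepends the leading zero while inheriting dtype and device (F.pad did the same via a generic constant pad):

```python
import torch

grid_thw = torch.tensor([[1, 4, 4], [2, 2, 2]], dtype=torch.int32)

# h * w patches per frame, repeated t times per item -> per-frame lengths
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
print(cu_seqlens.tolist())  # [0, 16, 20, 24]
```

Building cu_seqlens on CPU and moving it to the device once, with non_blocking=True, queues the host-to-device copy without forcing an immediate synchronization.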
@@ -1548,7 +1549,6 @@ def _process_image_input(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1559,20 +1559,17 @@ def _process_image_input(
                     self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(
         self, video_input: Glm4vVideoInputs
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
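Both paths now derive the per-item split sizes directly from the grid tensor; the video path in the next hunk gets the identical change. A worked example with hypothetical values, assuming self.visual.spatial_merge_size is 2:

```python
import torch

grid_thw = torch.tensor([[1, 4, 4], [2, 2, 2]], dtype=torch.int32)
merge_size = 2  # assumed value of self.visual.spatial_merge_size

# t * h * w patches per item, reduced by merge_size**2 after spatial merging
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
print(sizes)  # [4, 2]

# embeds.split(sizes) then yields one (size_i, hidden_dim) tensor per item
embeds = torch.randn(sum(sizes), 8)  # hidden_dim = 8, purely illustrative
print([tuple(c.shape) for c in embeds.split(sizes)])  # [(4, 8), (2, 8)]
```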
@@ -1588,15 +1585,11 @@ def _process_video_input(
                     rope_type="rope_3d",
                 )
             else:
-                video_embeds = self.visual(
-                    pixel_values_videos, grid_thw=grid_thw.tolist()
-                )
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: