Skip to content

Commit 298f74f

Browse files
authored
Set device for encode (#420)
1 parent ee8babb commit 298f74f

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

fastvideo/v1/models/vaes/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def __init__(self, config: VAEConfig, **kwargs) -> None:
3939
self.use_temporal_tiling = config.use_temporal_tiling
4040
self.use_parallel_tiling = config.use_parallel_tiling
4141

42+
def to(self, device) -> 'ParallelTiledVAE':
43+
return self
44+
4245
@property
4346
def temporal_compression_ratio(self) -> int:
4447
return cast(int, self.config.temporal_compression_ratio)

fastvideo/v1/pipelines/stages/decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from fastvideo.v1.fastvideo_args import FastVideoArgs
99
from fastvideo.v1.logger import init_logger
10+
from fastvideo.v1.models.vaes.common import ParallelTiledVAE
1011
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
1112
from fastvideo.v1.pipelines.stages.base import PipelineStage
1213
from fastvideo.v1.utils import PRECISION_TO_TYPE
@@ -23,7 +24,7 @@ class DecodingStage(PipelineStage):
2324
"""
2425

2526
def __init__(self, vae) -> None:
26-
self.vae = vae
27+
self.vae: ParallelTiledVAE = vae
2728

2829
def forward(
2930
self,

fastvideo/v1/pipelines/stages/encoding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def forward(
4646
Returns:
4747
The batch with encoded outputs.
4848
"""
49+
self.vae = self.vae.to(fastvideo_args.device)
50+
4951
image_path = batch.image_path
5052
# TODO(will): remove this once we add input/output validation for stages
5153
if image_path is None:

0 commit comments

Comments
 (0)