Skip to content

Commit 30db272

Browse files
committed
[WIP][Feat] distillation
1 parent b392e6a commit 30db272

File tree

6 files changed, with 1661 additions and 9 deletions.

fastvideo/v1/fastvideo_args.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class FastVideoArgs:
3434
# Distributed executor backend
3535
distributed_executor_backend: str = "mp"
3636

37-
inference_mode: bool = True # if False == training mode
37+
mode: str = "inference" # Options: "inference", "training", "distill"
3838

3939
# HuggingFace specific parameters
4040
trust_remote_code: bool = False
@@ -115,7 +115,15 @@ class FastVideoArgs:
115115

116116
@property
117117
def training_mode(self) -> bool:
118-
return not self.inference_mode
118+
return self.mode == "training"
119+
120+
@property
121+
def distill_mode(self) -> bool:
122+
return self.mode == "distill"
123+
124+
@property
125+
def inference_mode(self) -> bool:
126+
return self.mode == "inference"
119127

120128
def __post_init__(self):
121129
pass
@@ -150,10 +158,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
150158
)
151159

152160
parser.add_argument(
153-
"--inference-mode",
154-
action=StoreBoolean,
155-
default=FastVideoArgs.inference_mode,
156-
help="Whether to use inference mode",
161+
"--mode",
162+
type=str,
163+
default=FastVideoArgs.mode,
164+
choices=["inference", "training", "distill"],
165+
help="The mode to use",
157166
)
158167

159168
# HuggingFace specific parameters

fastvideo/v1/pipelines/composed_pipeline_base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ def __init__(self,
9494
self.initialize_validation_pipeline(self.training_args)
9595
self.initialize_training_pipeline(self.training_args)
9696

97+
if fastvideo_args.distill_mode:
98+
self.initialize_distillation_pipeline(fastvideo_args)
99+
100+
if fastvideo_args.log_validation:
101+
self.initialize_validation_pipeline(fastvideo_args)
102+
97103
self.initialize_pipeline(fastvideo_args)
98104

99105
if not fastvideo_args.training_mode:
@@ -109,6 +115,10 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
109115
"if log_validation is True, the pipeline must implement this method"
110116
)
111117

118+
def initialize_distillation_pipeline(self, fastvideo_args: FastVideoArgs):
119+
raise NotImplementedError(
120+
"if distill_mode is True, the pipeline must implement this method")
121+
112122
@classmethod
113123
def from_pretrained(cls,
114124
model_path: str,
@@ -148,7 +158,7 @@ def from_pretrained(cls,
148158
config_args = shallow_asdict(config)
149159
config_args.update(kwargs)
150160

151-
if args is None or args.inference_mode:
161+
if args.mode == "inference":
152162
fastvideo_args = FastVideoArgs(model_path=model_path,
153163
device_str=device or "cuda" if
154164
torch.cuda.is_available() else "cpu",
@@ -172,7 +182,7 @@ def from_pretrained(cls,
172182
fastvideo_args.num_gpus = int(os.environ.get("WORLD_SIZE", 1))
173183
fastvideo_args.use_cpu_offload = False
174184
# carry over the requested non-inference mode ("training" or "distill")
175-
fastvideo_args.inference_mode = False
185+
fastvideo_args.mode = args.mode
176186
# we hijack the precision to be the master weight type so that the
177187
# model is loaded with the correct precision. Subsequently we will
178188
# use FSDP2's MixedPrecisionPolicy to set the precision for the

0 commit comments

Comments
 (0)