
Commit 1264745

[Misc] Small fixes to Torch code (#395)
Co-authored-by: Brayden Zhong <[email protected]>
Parent: 298f74f

5 files changed (+17 additions, -18 deletions)

csrc/sliding_tile_attention/test/test_sta.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 from flex_sta_ref import get_sliding_tile_attention_mask
 from st_attn import sliding_tile_attention
 from torch.nn.attention.flex_attention import flex_attention
-# from flash_attn_interface import flash_attn_func
 from tqdm import tqdm

 flex_attention = torch.compile(flex_attention, dynamic=False)
@@ -23,7 +22,7 @@ def h100_fwd_kernel_test(Q, K, V, kernel_size):
 def generate_tensor(shape, mean, std, dtype, device):
     tensor = torch.randn(shape, dtype=dtype, device=device)

-    magnitude = torch.norm(tensor, dim=-1, keepdim=True)
+    magnitude = torch.linalg.norm(tensor, dim=-1, keepdim=True)
     scaled_tensor = tensor * (torch.randn(magnitude.shape, dtype=dtype, device=device) * std + mean) / magnitude

     return scaled_tensor.contiguous()
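Context for the norm change: torch.norm is deprecated in favor of the torch.linalg namespace, and for this call pattern the two return identical values. A minimal sketch of the equivalence (shapes are arbitrary, chosen only for illustration):

    import torch

    x = torch.randn(4, 8)
    # Deprecated API used before the change:
    old = torch.norm(x, dim=-1, keepdim=True)
    # Preferred replacement used after the change:
    new = torch.linalg.norm(x, dim=-1, keepdim=True)
    assert torch.allclose(old, new)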

fastvideo/models/hunyuan/idle_config.py

Lines changed: 2 additions & 2 deletions
@@ -237,7 +237,7 @@ def add_inference_args(parser: argparse.ArgumentParser):
         type=str,
         default="540p",
         choices=["540p", "720p"],
-        help="Root path of all the models, including t2v models and extra models.",
+        help="The resolution of the model.",
     )
     group.add_argument(
         "--load-key",
@@ -361,7 +361,7 @@ def add_parallel_args(parser: argparse.ArgumentParser):
         "--ring-degree",
         type=int,
         default=1,
-        help="Ulysses degree.",
+        help="Ring degree.",
     )

     return parser
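Both help strings appear to have been copy-pasted from other flags; the diff only corrects the descriptions. A self-contained argparse sketch of the corrected options (the --resolution flag name is hypothetical, since the hunk begins below the argument name; only --ring-degree is visible in the diff):

    import argparse

    parser = argparse.ArgumentParser()
    # Hypothetical flag name; the diff shows only type/default/choices/help.
    parser.add_argument("--resolution", type=str, default="540p",
                        choices=["540p", "720p"],
                        help="The resolution of the model.")
    # Ring-attention degree, distinct from the Ulysses degree the old help text described.
    parser.add_argument("--ring-degree", type=int, default=1,
                        help="Ring degree.")

    args = parser.parse_args(["--resolution", "720p", "--ring-degree", "2"])
    print(args.resolution, args.ring_degree)  # -> 720p 2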

fastvideo/models/hunyuan/inference.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from fastvideo.utils.parallel_states import nccl_info


-class Inference(object):
+class Inference:

     def __init__(
         self,
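The explicit object base is redundant: in Python 3 every class inherits from object implicitly, so the two spellings define identical classes. A quick check:

    class A:
        pass

    class B(object):
        pass

    # Both MROs end at object; the explicit base adds nothing in Python 3.
    assert A.__mro__ == (A, object) and B.__mro__ == (B, object)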

fastvideo/models/hunyuan/prompt_rewrite.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def get_rewrite_prompt(ori_prompt, mode="Normal"):
     elif mode == "Master":
         prompt = master_mode_prompt.format(input=ori_prompt)
     else:
-        raise Exception("Only supports Normal and Normal", mode)
+        raise Exception("Only supports Normal and Master mode, but got {}".format(mode))
     return prompt

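The old message ("Only supports Normal and Normal") was a copy-paste typo, and passing mode as a second argument placed it in the exception's args tuple rather than in the message itself. A hypothetical further tightening, not part of this commit, would raise ValueError with an f-string:

    mode = "Fancy"  # any unsupported mode, for illustration
    try:
        if mode not in ("Normal", "Master"):
            raise ValueError(f"Only supports Normal and Master mode, but got {mode}")
    except ValueError as e:
        print(e)  # Only supports Normal and Master mode, but got Fancy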

fastvideo/models/stepvideo/text_encoder/stepllm.py

Lines changed: 12 additions & 12 deletions
@@ -267,25 +267,25 @@ def forward(
 class STEP1TextEncoder(torch.nn.Module):

     def __init__(self, model_dir, max_length=320):
-        super(STEP1TextEncoder, self).__init__()
+        super().__init__()
         self.max_length = max_length
         self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
         text_encoder = Step1Model.from_pretrained(model_dir)
         self.text_encoder = text_encoder.eval().to(torch.bfloat16)

     @torch.no_grad
+    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
     def forward(self, prompts, with_mask=True, max_length=None):
         self.device = next(self.text_encoder.parameters()).device
-        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            if type(prompts) is str:
-                prompts = [prompts]
-
-            txt_tokens = self.text_tokenizer(prompts,
-                                             max_length=max_length or self.max_length,
-                                             padding="max_length",
-                                             truncation=True,
-                                             return_tensors="pt")
-            y = self.text_encoder(txt_tokens.input_ids.to(self.device),
+        if type(prompts) is str:
+            prompts = [prompts]
+
+        txt_tokens = self.text_tokenizer(prompts,
+                                         max_length=max_length or self.max_length,
+                                         padding="max_length",
+                                         truncation=True,
+                                         return_tensors="pt")
+        y = self.text_encoder(txt_tokens.input_ids.to(self.device),
                               attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None)
-            y_mask = txt_tokens.attention_mask
+        y_mask = txt_tokens.attention_mask
         return y.transpose(0, 1), y_mask
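The refactor moves grad-disabling and mixed precision from an in-body with block to decorators, and drops the deprecated torch.cuda.amp.autocast spelling in favor of torch.autocast(device_type='cuda', ...). Stacking the decorators is what allows the large dedent in the hunk above. A minimal sketch of the decorator form (toy function; assumes a CUDA device is available):

    import torch

    @torch.no_grad()
    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
    def matmul_bf16(x):
        # Matmul runs under bf16 autocast with gradient tracking disabled.
        return x @ x

    x = torch.randn(4, 4, device='cuda', requires_grad=True)
    y = matmul_bf16(x)
    assert y.dtype == torch.bfloat16 and not y.requires_grad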
