@@ -18,7 +18,7 @@ def sinusoidal_embedding_1d(dim, position):
1818 # preprocess
1919 assert dim % 2 == 0
2020 half = dim // 2
21- position = position .type (torch .float64 )
21+ position = position .type (torch .float32 )
2222
2323 # calculation
2424 sinusoid = torch .outer (
@@ -353,7 +353,7 @@ def __init__(self,
353353
354354 # embeddings
355355 self .patch_embedding = operations .Conv3d (
356- in_dim , dim , kernel_size = patch_size , stride = patch_size , device = operation_settings .get ("device" ), dtype = operation_settings . get ( "dtype" ) )
356+ in_dim , dim , kernel_size = patch_size , stride = patch_size , device = operation_settings .get ("device" ), dtype = torch . float32 )
357357 self .text_embedding = nn .Sequential (
358358 operations .Linear (text_dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )), nn .GELU (approximate = 'tanh' ),
359359 operations .Linear (dim , dim , device = operation_settings .get ("device" ), dtype = operation_settings .get ("dtype" )))
@@ -411,7 +411,7 @@ def forward_orig(
411411 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
412412 """
413413 # embeddings
414- x = self .patch_embedding (x )
414+ x = self .patch_embedding (x . float ()). to ( x . dtype )
415415 grid_sizes = x .shape [2 :]
416416 x = x .flatten (2 ).transpose (1 , 2 )
417417
0 commit comments