Skip to content

Commit ac1eece

Browse files
committed
undo dtype changes
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 9039eb5 commit ac1eece

File tree

1 file changed

+0
-6
lines changed
  • src/compressed_tensors/transform/utils

1 file changed

+0
-6
lines changed

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,10 @@ def apply_transform_weight(
6565
head_dim = weight.shape[0]
6666
num_heads = value.shape[axis] // head_dim
6767

68-
value_dtype = value.dtype
69-
value = value.to(torch.float64)
70-
weight = weight.to(torch.float64)
71-
7268
value = value.unflatten(axis, (num_heads, head_dim))
7369
value = fn(weight, value)
7470
value = value.flatten(axis - 1, axis)
7571

76-
value = value.to(value_dtype)
77-
7872
return value
7973

8074

0 commit comments

Comments
 (0)