Skip to content

Commit dde3df3

Browse files
authored
Replace torch.empty with torch.zeros
Differential Revision: D64875312 Pull Request resolved: #6478
1 parent 8234c14 commit dde3df3

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

examples/models/llama/source_transformation/lora.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(
7070
precision=precision,
7171
scales_precision=scales_precision,
7272
)
73+
# TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
74+
self.zeros = torch.zeros_like(self.zeros)
7375
self.adaptor = LoRAAdaptorLinear(
7476
in_features,
7577
out_features,

examples/models/llama/source_transformation/pre_quantization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
4646
precision=precision,
4747
scales_precision=scales_precision,
4848
)
49+
# TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
50+
new_linear.zeros = torch.zeros_like(new_linear.zeros)
4951
return new_linear
5052

5153
_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)

examples/models/llama/source_transformation/quantize.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def __init__(
375375
self.in_features = in_features
376376
self.out_features = out_features
377377
self.register_buffer(
378-
"weight", torch.empty((out_features, in_features), dtype=torch.int8)
378+
"weight", torch.zeros((out_features, in_features), dtype=torch.int8)
379379
)
380380
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
381381

@@ -448,18 +448,18 @@ def __init__(
448448
# currently storing unpacked int8 weights
449449
self.register_buffer(
450450
"weight",
451-
torch.empty((out_features, in_features), dtype=torch.int8),
451+
torch.zeros((out_features, in_features), dtype=torch.int8),
452452
)
453453
self.register_buffer(
454454
"scales",
455-
torch.empty(
455+
torch.zeros(
456456
(out_features),
457457
dtype=torch.float32,
458458
),
459459
)
460460
self.register_buffer(
461461
"zeros",
462-
torch.empty(
462+
torch.zeros(
463463
(out_features),
464464
dtype=torch.float32,
465465
),
@@ -632,15 +632,15 @@ def __init__(
632632
if not packed:
633633
self.register_buffer(
634634
"weight",
635-
torch.empty(
635+
torch.zeros(
636636
(vocab_size, embedding_dim), dtype=torch.int8, device=device
637637
),
638638
)
639639
else: # packed
640640
if bitwidth == 2:
641641
self.register_buffer(
642642
"weight",
643-
torch.empty(
643+
torch.zeros(
644644
(vocab_size, embedding_dim // 4),
645645
dtype=torch.uint8,
646646
device=device,
@@ -649,7 +649,7 @@ def __init__(
649649
elif bitwidth == 4:
650650
self.register_buffer(
651651
"weight",
652-
torch.empty(
652+
torch.zeros(
653653
(vocab_size, embedding_dim // 2),
654654
dtype=torch.uint8,
655655
device=device,

0 commit comments

Comments
 (0)