
Commit c74b0b7

Lunwen He authored and facebook-github-bot committed
Replace torch.empty with torch.zeros (pytorch#6478)
Summary:
X-link: pytorch/ao#1157

It turns out that it is unsafe to use `torch.empty` in an OSS environment, because `torch.empty` creates a tensor with uninitialized data: the buffer can hold arbitrary values depending on what was previously left in that memory, which leads to inconsistent behavior. This PR replaces `torch.empty` with `torch.zeros` so that the buffers are properly initialized and the behavior is consistent.

Reviewed By: msaroufim

Differential Revision: D64875312
1 parent e93ad5f commit c74b0b7
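
As a quick, standalone illustration of the issue described in the summary (not part of the commit itself), the snippet below shows why `torch.empty`-backed buffers can vary between runs while `torch.zeros` is deterministic:

import torch

# torch.empty only allocates memory; the contents are whatever bytes happen
# to be in that allocation, so values can differ between calls and runs.
a = torch.empty(4, dtype=torch.float32)
b = torch.empty(4, dtype=torch.float32)
print(a)  # arbitrary values
print(b)  # may differ from `a`, and from run to run

# torch.zeros allocates and initializes, so the result is always the same.
c = torch.zeros(4, dtype=torch.float32)
print(c)  # tensor([0., 0., 0., 0.])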

3 files changed: 11 additions, 7 deletions

examples/models/llama/source_transformation/lora.py

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ def __init__(
             precision=precision,
             scales_precision=scales_precision,
         )
+        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
+        super().zeros = torch.zeros_like(super().zeros)
         self.adaptor = LoRAAdaptorLinear(
             in_features,
             out_features,

examples/models/llama/source_transformation/pre_quantization.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             precision=precision,
             scales_precision=scales_precision,
         )
+        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
+        new_linear.zeros = torch.zeros_like(new_linear.zeros)
         return new_linear

     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
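
For readers unfamiliar with the swap pattern used in this file, here is a minimal, self-contained sketch of filter-and-replace module substitution. It is written generically and does not reproduce the helper imported by the ExecuTorch source; the name `replace_modules` below is illustrative only:

import torch.nn as nn
from typing import Callable

def replace_modules(
    module: nn.Module,
    replacement_fn: Callable[[nn.Module], nn.Module],
    filter_fn: Callable[[nn.Module], bool],
) -> None:
    # Recursively walk the module tree; children matching filter_fn are
    # swapped in place with whatever replacement_fn returns for them.
    for name, child in module.named_children():
        if filter_fn(child):
            setattr(module, name, replacement_fn(child))
        else:
            replace_modules(child, replacement_fn, filter_fn)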

examples/models/llama/source_transformation/quantize.py

Lines changed: 7 additions & 7 deletions
@@ -375,7 +375,7 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
         )
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

@@ -448,18 +448,18 @@ def __init__(
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.zeros((out_features, in_features), dtype=torch.int8),
         )
         self.register_buffer(
             "scales",
-            torch.empty(
+            torch.zeros(
                 (out_features),
                 dtype=torch.float32,
             ),
         )
         self.register_buffer(
             "zeros",
-            torch.empty(
+            torch.zeros(
                 (out_features),
                 dtype=torch.float32,
             ),
@@ -632,15 +632,15 @@ def __init__(
         if not packed:
             self.register_buffer(
                 "weight",
-                torch.empty(
+                torch.zeros(
                     (vocab_size, embedding_dim), dtype=torch.int8, device=device
                 ),
             )
         else:  # packed
             if bitwidth == 2:
                 self.register_buffer(
                     "weight",
-                    torch.empty(
+                    torch.zeros(
                         (vocab_size, embedding_dim // 4),
                         dtype=torch.uint8,
                         device=device,
@@ -649,7 +649,7 @@ def __init__(
             elif bitwidth == 4:
                 self.register_buffer(
                     "weight",
-                    torch.empty(
+                    torch.zeros(
                         (vocab_size, embedding_dim // 2),
                         dtype=torch.uint8,
                         device=device,
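
For context, here is a minimal sketch of the buffer-registration pattern touched throughout quantize.py (not the actual ExecuTorch classes; the class name and forward pass are illustrative). With `torch.zeros`, the placeholder buffers hold deterministic values until real quantized weights are loaded into them:

import torch
import torch.nn as nn

class Int8WeightOnlyLinearSketch(nn.Module):
    # Hypothetical stand-in for the quantized linear modules edited above.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Zero-initialized placeholder; with torch.empty this buffer would
        # hold whatever bytes the allocator happened to return.
        self.register_buffer(
            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: scale each output channel of the int8 weight.
        w = self.weight.to(x.dtype) * self.scales.to(x.dtype).unsqueeze(-1)
        return torch.nn.functional.linear(x, w)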
