Commit 57595f5

Lunwen He authored and facebook-github-bot committed
Replace torch.empty with torch.zeros (pytorch#1157)
Summary: X-link: pytorch/executorch#6478

It turns out that using `torch.empty` is unsafe in the OSS environment because `torch.empty` creates a tensor with uninitialized data: the buffer holds whatever values were left in that piece of memory, which leads to inconsistent behavior. This PR replaces `torch.empty` with `torch.zeros` so the buffers are properly initialized and behave consistently.

Reviewed By: msaroufim

Differential Revision: D64875312
1 parent 39a473e commit 57595f5
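
To make the issue concrete, here is a minimal sketch (not code from this repository; the shapes and dtype are arbitrary examples) contrasting the two constructors:

```python
import torch

# torch.empty only allocates memory: the tensor's contents are whatever
# bytes happened to be in that allocation, so values can differ run to run.
uninit = torch.empty((2, 3), dtype=torch.int32)
print(uninit)  # arbitrary, non-reproducible values

# torch.zeros allocates and initializes the memory, so the buffer always
# starts from the same well-defined state.
init = torch.zeros((2, 3), dtype=torch.int32)
print(init)  # always a tensor of zeros
```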

File tree

1 file changed: +5 additions, -5 deletions

torchao/quantization/GPTQ.py

Lines changed: 5 additions & 5 deletions
@@ -583,12 +583,12 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
+            torch.zeros((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device)
         )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
-            torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
+            torch.zeros((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device)
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -935,18 +935,18 @@ def __init__(
         # currently storing unpacked int8 weights
         self.register_buffer(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.zeros((out_features, in_features), dtype=torch.int8),
         )
         self.register_buffer(
             "scales",
-            torch.empty(
+            torch.zeros(
                 (out_features, in_features // groupsize),
                 dtype=scales_precision,
             ),
         )
         self.register_buffer(
             "zeros",
-            torch.empty(
+            torch.zeros(
                 (out_features, in_features // groupsize),
                 dtype=scales_precision,
             ),
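
For context, a minimal, hypothetical sketch of the buffer-registration pattern the diff touches (the class name, shapes, and dtypes below are illustrative placeholders, not the real torchao GPTQ modules): buffers pre-allocated with `torch.zeros` hold deterministic values until real quantized data is copied in, e.g. by `load_state_dict`.

```python
import torch
import torch.nn as nn

# Hypothetical simplified module; not the actual torchao class.
class QuantizedLinearPlaceholder(nn.Module):
    def __init__(self, out_features: int, in_features: int, groupsize: int = 32):
        super().__init__()
        # Using torch.zeros (not torch.empty) so the placeholders are
        # deterministic until real quantized weights are loaded into them.
        self.register_buffer(
            "weight",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer(
            "scales",
            torch.zeros((out_features, in_features // groupsize), dtype=torch.bfloat16),
        )
        self.register_buffer(
            "zeros",
            torch.zeros((out_features, in_features // groupsize), dtype=torch.bfloat16),
        )
```

`load_state_dict` later copies the saved quantized values into these buffers in place, so anything observed before loading comes from the initialization above rather than from leftover memory.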
