Skip to content

Commit 79083cd

Browse files
committed
fix conflict after rebase
1 parent 1b09f44 commit 79083cd

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

intel_pytorch_extension_py/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_auto_optimization():
3232
def get_train():
3333
return core.get_train()
3434

35-
def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train, configure_file = None):
35+
def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train = False, configure_file = None):
3636
if mixed_dtype == torch.bfloat16:
3737
core.enable_mix_bf16_fp32()
3838
core.disable_mix_int8_fp32()

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -794,16 +794,12 @@ at::Tensor AtenIpexCPUDev::dil_linear(
794794

795795
dil::tensor y = dbl::linear::linear_impl(x, w, b, output_scale);
796796

797-
auto input_size = self.sizes();
798-
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
799-
output_size.push_back(weight.size(0));
800-
801797
auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(y));
802798

803799
if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
804800
insert_or_updata_observer(self, aten_output, "Linear");
805801
}
806-
802+
807803
if (self.dim() > 2) {
808804
auto input_size = self.sizes();
809805
std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
@@ -1358,7 +1354,7 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
13581354
dbl::comm::reorder_to_dtype(input, at::kFloat);
13591355
}
13601356
} else {
1361-
dbl::comm::reorder_to_bf16_for_mix_prec(input);
1357+
dbl::comm::reorder_to_bf16_for_mix_prec(input, true);
13621358
}
13631359

13641360
const dil::tensor& x = dbl::comm::try_gen_dil_tensor(input);
@@ -1388,7 +1384,7 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
13881384
dbl::comm::reorder_to_dtype(input, at::kFloat);
13891385
}
13901386
} else {
1391-
dbl::comm::reorder_to_bf16_for_mix_prec(input);
1387+
dbl::comm::reorder_to_bf16_for_mix_prec(input, true);
13921388
}
13931389

13941390
if (check_auto_mix_int8_fp32() && check_int8_calibration()) {

0 commit comments

Comments
 (0)