Skip to content

Commit 7d47bee

Browse files
build: manually update PyTorch version (#4415)
This commit sets the PyTorch and TorchVision versions to nightly release 2025-12-22. Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 7755008 commit 7d47bee

File tree

6 files changed

+33
-9
lines changed

6 files changed

+33
-9
lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12299,9 +12299,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1229912299
" return %2 : !torch.int\n"
1230012300
" }\n"
1230112301
" func.func @\"__torch_mlir_dtype_fn.aten.softshrink\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
12302+
" %none = torch.constant.none\n"
12303+
" %str = torch.constant.str \"AssertionError: \"\n"
1230212304
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12303-
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
12304-
" return %1 : !torch.int\n"
12305+
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
12306+
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
12307+
" torch.prim.If %2 -> () {\n"
12308+
" torch.prim.If.yield\n"
12309+
" } else {\n"
12310+
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12311+
" torch.prim.If.yield\n"
12312+
" }\n"
12313+
" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
12314+
" return %3 : !torch.int\n"
1230512315
" }\n"
1230612316
" func.func @\"__torch_mlir_dtype_fn.aten.polar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
1230712317
" %int9 = torch.constant.int 9\n"
@@ -16854,11 +16864,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1685416864
" %int6 = torch.constant.int 6\n"
1685516865
" %none = torch.constant.none\n"
1685616866
" %str = torch.constant.str \"AssertionError: \"\n"
16867+
" %int15 = torch.constant.int 15\n"
16868+
" %true = torch.constant.bool true\n"
1685716869
" %int5 = torch.constant.int 5\n"
1685816870
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
1685916871
" %1 = torch.prim.If %arg2 -> (!torch.int) {\n"
1686016872
" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
16861-
" torch.prim.If %2 -> () {\n"
16873+
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
16874+
" torch.prim.If.yield %true : !torch.bool\n"
16875+
" } else {\n"
16876+
" %4 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
16877+
" torch.prim.If.yield %4 : !torch.bool\n"
16878+
" }\n"
16879+
" torch.prim.If %3 -> () {\n"
1686216880
" torch.prim.If.yield\n"
1686316881
" } else {\n"
1686416882
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5238,7 +5238,6 @@ class DecomposeAtenUpsampleNearestVecOp
52385238
using OpRewritePattern<UpsampleVecOp>::OpRewritePattern;
52395239
LogicalResult matchAndRewrite(UpsampleVecOp op,
52405240
PatternRewriter &rewriter) const override {
5241-
Value scales = op.getScaleFactors();
52425241
static_assert(std::is_same_v<UpsampleVecOp, AtenUpsampleNearest1dVecOp> ||
52435242
std::is_same_v<UpsampleVecOp, AtenUpsampleNearest2dVecOp>);
52445243
Value cstMode = Torch::ConstantStrOp::create(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3509,6 +3509,12 @@
35093509
"MaxPool2dStaticCeilModeTrueReduceOutputModule_basic",
35103510
}
35113511

3512+
if torch_version_for_comparison() > version.parse("2.10.0.dev"):
3513+
ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
3514+
"Aten_CastLongModule_basic",
3515+
"Aten_CastFloatModule_basic",
3516+
}
3517+
35123518
if torch_version_for_comparison() < version.parse("2.4.0.dev"):
35133519
STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - {
35143520
"AtenIntMM_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,9 +3013,10 @@ def aten〇hardshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int
30133013
return torch.int64
30143014
return self_dtype
30153015

3016-
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5))
3016+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5, error_types={*all_integer_dtypes()}))
30173017
def aten〇softshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int:
30183018
self_rank, self_dtype = self_rank_dtype
3019+
assert not is_integer_dtype(self_dtype)
30193020
return _get_dtype_of_floating_point_op(self_dtype)
30203021

30213022

@@ -6085,12 +6086,12 @@ def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dty
60856086
# _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) +
60866087
_check_tensors_with_the_same_dtype(
60876088
num_of_tensors=1,
6088-
error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]),
6089+
error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.float32, torch.float64]),
60896090
dim=0, half_to_float=True))
60906091
def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int:
60916092
self_rank, self_dtype = self_rank_dtype
60926093
if half_to_float:
6093-
assert self_dtype == torch.float16
6094+
assert self_dtype == torch.float16 or self_dtype == torch.bfloat16
60946095
return torch.float32
60956096
return self_dtype
60966097

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
-f https://download.pytorch.org/whl/nightly/cpu/torch/
22
--pre
3-
torch==2.10.0.dev20251016
3+
torch==2.11.0.dev20251222

torchvision-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
-f https://download.pytorch.org/whl/nightly/cpu/torchvision/
22
--pre
3-
torchvision==0.25.0.dev20251016
3+
torchvision==0.25.0.dev20251222

0 commit comments

Comments
 (0)