diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py
new file mode 100644
index 0000000000000..7db2e1ce9233c
--- /dev/null
+++ b/.automation_scripts/parse_xml_results.py
@@ -0,0 +1,178 @@
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need to check against None here, if not skipped doesn't work as expected
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ work_flow_name: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ print(f"Parsing {tag}s for test report: {report}")
+
+ job_id = get_job_id(report)
+ print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled test is only few MB, but will balloon up to a much bigger XML file after
+ # extracting from a dozen to few hundred MB
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["work_flow_name"] = work_flow_name
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["work_flow_name"] ) ] = case
+ elif tag == 'testsuite':
+ case["work_flow_name"] = work_flow_name
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+
+ test_cases[ ( case["invoking_file"], case["invoking_xml"], case["work_flow_name"] ) ] = case
+
+ return test_cases
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # my_inner_text my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #
+ # hello
+ # world
+ # another
+ #
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
\ No newline at end of file
diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py
new file mode 100644
index 0000000000000..514afd19624c3
--- /dev/null
+++ b/.automation_scripts/run_pytorch_unit_tests.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+from subprocess import STDOUT, CalledProcessError
+
+from collections import namedtuple
+from datetime import datetime
+from pathlib import Path
+from parse_xml_results import (
+ parse_xml_report
+)
+from pprint import pprint
+from typing import Any, Dict, List
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+DEFAULT_CORE_TESTS = [
+ "test_nn",
+ "test_torch",
+ "test_cuda",
+ "test_ops",
+ "test_unary_ufuncs",
+ "test_autograd",
+ "inductor/test_torchinductor"
+]
+
+DISTRIBUTED_CORE_TESTS = [
+ "distributed/test_c10d_common",
+ "distributed/test_c10d_nccl",
+ "distributed/test_distributed_spawn"
+]
+
+CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log"
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."):
+ test_cases = {}
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ test_cases.update(
+ parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ workflow_name
+ )
+ )
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if type_message.__contains__('type') and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_test_file_running_time(test_suite):
+ if test_suite.__contains__('time'):
+ return test_suite["time"]
+ return 0
+
+def get_test_running_time(test_case):
+ if test_case.__contains__('time'):
+ return test_case["time"]
+ return ""
+
+def summarize_xml_files(path, workflow_name):
+ # statistics
+ TOTAL_TEST_NUM = 0
+ TOTAL_PASSED_NUM = 0
+ TOTAL_SKIPPED_NUM = 0
+ TOTAL_XFAIL_NUM = 0
+ TOTAL_FAILED_NUM = 0
+ TOTAL_ERROR_NUM = 0
+ TOTAL_EXECUTION_TIME = 0
+
+ #parse the xml files
+ test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
+ test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
+ test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
+ # results dict
+ res = {}
+ res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ]
+ test_file_items = set()
+ for (k,v) in list(test_suites.items()):
+ file_name = k[0]
+ if not file_name in test_file_items:
+ test_file_items.add(file_name)
+ # initialization
+ for item in res_item_list:
+ temp_item = test_file_and_status(file_name, item)
+ res[temp_item] = {}
+ temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
+ res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
+ test_running_time = get_test_file_running_time(v)
+ res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+ else:
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ test_running_time = get_test_file_running_time(v)
+ res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+
+ for (k,v) in list(test_cases.items()):
+ file_name = k[0]
+ class_name = k[1]
+ test_name = k[2]
+ combined_name = file_name + "::" + class_name + "::" + test_name
+ test_status = get_test_status(v)
+ test_running_time = get_test_running_time(v)
+ test_message = get_test_message(v, test_status)
+ test_info_value = ""
+ test_tuple_key_status = test_file_and_status(file_name, test_status)
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ TOTAL_TEST_NUM += 1
+ res[test_tuple_key_statistics]["TOTAL"] += 1
+ if test_status == "PASSED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["PASSED"] += 1
+ TOTAL_PASSED_NUM += 1
+ elif test_status == "SKIPPED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["SKIPPED"] += 1
+ TOTAL_SKIPPED_NUM += 1
+ elif test_status == "XFAILED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["XFAILED"] += 1
+ TOTAL_XFAIL_NUM += 1
+ elif test_status == "FAILED":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["FAILED"] += 1
+ TOTAL_FAILED_NUM += 1
+ elif test_status == "ERROR":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["ERROR"] += 1
+ TOTAL_ERROR_NUM += 1
+
+ # generate statistics_dict
+ statistics_dict = {}
+ statistics_dict["TOTAL"] = TOTAL_TEST_NUM
+ statistics_dict["PASSED"] = TOTAL_PASSED_NUM
+ statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM
+ statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
+ statistics_dict["FAILED"] = TOTAL_FAILED_NUM
+ statistics_dict["ERROR"] = TOTAL_ERROR_NUM
+ statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
+ aggregate_item = workflow_name + "_aggregate"
+ total_item = test_file_and_status(aggregate_item, "STATISTICS")
+ res[total_item] = statistics_dict
+
+ return res
+
+def run_command_and_capture_output(cmd):
+ try:
+ print(f"Running command '{cmd}'")
+ with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:
+ print(f"========================================", file=output_file, flush=True)
+ print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well
+ print(f"========================================", file=output_file, flush=True)
+ p = subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True)
+ except CalledProcessError as e:
+ print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!")
+
+def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/"
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/"
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/"
+ # use test.sh for tests execution
+ run_command_and_capture_output(test_shell_path)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+ return entire_results_dict
+
+def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + distributed_priority_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return priority_results_dict
+
+def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ default_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ distributed_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/"
+ inductor_selected_test_suites = ""
+ non_inductor_selected_test_suites = ""
+ for item in selected_list:
+ if "inductor/" in item:
+ inductor_selected_test_suites += item
+ inductor_selected_test_suites += " "
+ else:
+ non_inductor_selected_test_suites += item
+ non_inductor_selected_test_suites += " "
+ if inductor_selected_test_suites != "":
+ inductor_selected_test_suites = inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ if non_inductor_selected_test_suites != "":
+ non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return selected_results_dict
+
+def run_test_and_summarize_results(
+ pytorch_root_dir: str,
+ priority_tests: bool,
+ test_config: List[str],
+ default_list: List[str],
+ distributed_list: List[str],
+ inductor_list: List[str],
+ skip_rerun: bool) -> Dict[str, Any]:
+
+ # copy current environment variables
+ _environ = dict(os.environ)
+
+ # modify path
+ test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh"
+ test_run_test_path = pytorch_root_dir + "/test/run_test.py"
+ repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/"
+ test_reports_src = pytorch_root_dir + "/test/test-reports/"
+ run_test_python_file = pytorch_root_dir + "/test/run_test.py"
+
+ # change directory to pytorch root
+ os.chdir(pytorch_root_dir)
+
+ # all test results dict
+ res_all_tests_dict = {}
+
+ # patterns
+ search_text = "--reruns=2"
+ replace_text = "--reruns=0"
+
+ # create logs folder
+ if not os.path.exists(repo_test_log_folder_path):
+ os.mkdir(repo_test_log_folder_path)
+
+ # Set common environment variables for all scenarios
+ os.environ['CI'] = '1'
+ os.environ['PYTORCH_TEST_WITH_ROCM'] = '1'
+ os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1'
+ os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda'
+ os.environ['CONTINUE_THROUGH_ERROR'] = 'True'
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(search_text, replace_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ # Time stamp
+ current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S")
+ print("Current date & time : ", current_datetime)
+ # performed as Job ID
+ str_current_datetime = str(current_datetime)
+ overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/"
+ os.mkdir(overall_logs_path_current_run)
+
+ global CONSOLIDATED_LOG_FILE_PATH
+ CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
+
+ # Check multi gpu availability if distributed tests are enabled
+ if ("distributed" in test_config) or len(distributed_list) != 0:
+ check_num_gpus_for_distributed()
+
+ # Install test requirements
+ command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
+ run_command_and_capture_output(command)
+
+ # Run entire tests for each workflow
+ if not priority_tests and not default_list and not distributed_list and not inductor_list:
+ # run entire tests for default, distributed and inductor workflows → use test.sh
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ # distributed test process
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ # inductor test process
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ if "distributed" in workflow_list:
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ if "inductor" in workflow_list:
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ # Run priority test for each workflow
+ elif priority_tests and not default_list and not distributed_list and not inductor_list:
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ # distributed test process
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ # will not run inductor priority tests
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ if "distributed" in workflow_list:
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ if "inductor" in workflow_list:
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ # Run specified tests for each workflow
+ elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests:
+ if default_list:
+ default_workflow_list = []
+ for item in default_list:
+ default_workflow_list.append(item)
+ res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list)
+ res_all_tests_dict["default"] = res_default_selected
+ if distributed_list:
+ distributed_workflow_list = []
+ for item in distributed_list:
+ distributed_workflow_list.append(item)
+ res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list)
+ res_all_tests_dict["distributed"] = res_distributed_selected
+ if inductor_list:
+ inductor_workflow_list = []
+ for item in inductor_list:
+ inductor_workflow_list.append(item)
+ res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list)
+ res_all_tests_dict["inductor"] = res_inductor_selected
+ else:
+ raise Exception("Invalid test configurations!")
+
+ # restore environment variables
+ os.environ.clear()
+ os.environ.update(_environ)
+
+ # restore files
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(replace_text, search_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ return res_all_tests_dict
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'")
+ parser.add_argument('--priority_tests', action='store_true', help="run priority tests only")
+ parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'")
+ parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 'distributed/test_c10d_common distributed/test_c10d_nccl'")
+ parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'")
+ parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory")
+ parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process")
+ parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n"
+ " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n"
+ "}}\n")
+ parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n"
+ "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_test \n"
+ "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor")
+ return parser.parse_args()
+
+def check_num_gpus_for_distributed():
+ p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
+ num_gpus_visible = int(p.stdout)
+ assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
+
+def main():
+ args = parse_args()
+ all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
+ pprint(dict(all_tests_results))
+
+if __name__ == "__main__":
+ main()
diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 10f1207e60e6c..d893bdd32ab34 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
+ac80c4190aa0321f761a08af97e1e1eee41f01d9
diff --git a/.ci/docker/common/install_cache.sh b/.ci/docker/common/install_cache.sh
index f38cb3d06d88b..80839990e4e6f 100644
--- a/.ci/docker/common/install_cache.sh
+++ b/.ci/docker/common/install_cache.sh
@@ -36,7 +36,12 @@ sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment
export PATH="/opt/cache/bin:$PATH"
# Setup compiler cache
-install_ubuntu
+if [ -n "$ROCM_VERSION" ]; then
+ curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache
+else
+ install_ubuntu
+fi
+
chmod a+x /opt/cache/bin/sccache
function write_sccache_stub() {
diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh
index 1b68e3c247839..b2fdebdcc4747 100755
--- a/.ci/docker/common/install_triton.sh
+++ b/.ci/docker/common/install_triton.sh
@@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
- TRITON_REPO="https://github.com/triton-lang/triton"
+ TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index d44dfb1ed67ae..93d32b803b199 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -117,10 +117,10 @@ ninja==1.11.1.4
#Pinned versions: 1.11.1.4
#test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py
-numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
-numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
+numba==0.60.0 ; python_version == "3.9"
+numba==0.61.2 ; python_version > "3.9"
#Description: Just-In-Time Compiler for Numerical Functions
-#Pinned versions: 0.55.2, 0.60.0
+#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073
@@ -136,12 +136,10 @@ numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py,
#test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py,
#test_binary_ufuncs.py
-numpy==1.22.4; python_version == "3.10"
-numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
-numpy==2.1.2; python_version >= "3.13"
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
-pandas==2.0.3; python_version < "3.13"
-pandas==2.2.3; python_version >= "3.13"
+pandas==2.2.3
#onnxruntime
#Description: scoring engine for Open Neural Network Exchange (ONNX) models
@@ -251,8 +249,8 @@ scikit-image==0.22.0
#Pinned versions: 0.20.3
#test that import:
-scipy==1.10.1 ; python_version <= "3.11"
-scipy==1.14.1 ; python_version >= "3.12"
+scipy==1.13.1 ; python_version == "3.9"
+scipy==1.14.1 ; python_version > "3.9"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
#Pinned versions: 1.10.1
@@ -311,8 +309,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x"
#Pinned versions:
#test that import:
-tensorboard==2.13.0 ; python_version < "3.13"
-tensorboard==2.18.0 ; python_version >= "3.13"
+tensorboard==2.18.0
#Description: Also included in .ci/docker/requirements-docs.txt
#Pinned versions:
#test that import: test_tensorboard
diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh
index ff9d8ad41cc92..9c9d223777466 100644
--- a/.ci/pytorch/common_utils.sh
+++ b/.ci/pytorch/common_utils.sh
@@ -67,13 +67,13 @@ function pip_install_whl() {
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
fi
}
diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh
index 11f9678579935..e64a690af1d6a 100755
--- a/.circleci/scripts/binary_populate_env.sh
+++ b/.circleci/scripts/binary_populate_env.sh
@@ -69,48 +69,6 @@ fi
export PYTORCH_BUILD_NUMBER=1
-# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
-TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
-TRITON_CONSTRAINT="platform_system == 'Linux'"
-
-if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" && ! "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
- TRITON_REQUIREMENT="triton==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
- if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
- TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
- TRITON_REQUIREMENT="pytorch-triton==${TRITON_VERSION}+git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
- fi
- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
-fi
-
-# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton rocm package
-if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then
- TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}"
- if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
- TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt)
- TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+git${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}"
- fi
- if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"
- else
- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
- fi
-fi
-
-# Set triton via PYTORCH_EXTRA_INSTALL_REQUIREMENTS for triton xpu package
-if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*xpu.* ]]; then
- TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_xpu_version.txt)
- TRITON_REQUIREMENT="pytorch-triton-xpu==${TRITON_VERSION}"
- if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then
- TRITON_SHORTHASH=$(cut -c1-8 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-xpu.txt)
- TRITON_REQUIREMENT="pytorch-triton-xpu==${TRITON_VERSION}+git${TRITON_SHORTHASH}"
- fi
- if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then
- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${TRITON_REQUIREMENT}"
- else
- export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
- fi
-fi
-
USE_GLOO_WITH_OPENSSL="ON"
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then
USE_GLOO_WITH_OPENSSL="OFF"
diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py
index 11fa8404273d3..e541e7a86f653 100644
--- a/.github/scripts/build_triton_wheel.py
+++ b/.github/scripts/build_triton_wheel.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
+import re
import shutil
import sys
from pathlib import Path
@@ -50,6 +51,31 @@ def patch_init_py(
with open(path, "w") as f:
f.write(orig)
+def get_rocm_version() -> str:
+ rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
+ rocm_version = "0.0.0"
+ rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
+ if not os.path.isfile(rocm_version_h):
+ rocm_version_h = f"{rocm_path}/include/rocm_version.h"
+
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
+ if os.path.isfile(rocm_version_h):
+ RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
+ RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
+ RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
+ major, minor, patch = 0, 0, 0
+ for line in open(rocm_version_h):
+ match = RE_MAJOR.search(line)
+ if match:
+ major = int(match.group(1))
+ match = RE_MINOR.search(line)
+ if match:
+ minor = int(match.group(1))
+ match = RE_PATCH.search(line)
+ if match:
+ patch = int(match.group(1))
+ rocm_version = str(major)+"."+str(minor)+"."+str(patch)
+ return rocm_version
def build_triton(
*,
@@ -65,13 +91,22 @@ def build_triton(
max_jobs = os.cpu_count() or 1
env["MAX_JOBS"] = str(max_jobs)
+ version_suffix = ""
+ if not release:
+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
+ # while release build should only include the version, i.e. 2.1.0
+ rocm_version = get_rocm_version()
+ version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
+ version += version_suffix
+
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
- triton_pkg_name = "pytorch-triton-rocm"
+ triton_pkg_name = "triton"
+ triton_repo = "https://github.com/ROCm/triton"
elif device == "xpu":
triton_pkg_name = "pytorch-triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
@@ -89,6 +124,7 @@ def build_triton(
# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
+ env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
@@ -127,6 +163,13 @@ def build_triton(
cwd=triton_basedir,
)
+ # For gpt-oss models, triton requires this extra triton_kernels wheel
+ # triton_kernels came after pytorch release/2.8
+ triton_kernels_dir = Path(f"{triton_basedir}/python/triton_kernels")
+ check_call([sys.executable, "-m", "build", "--wheel"], cwd=triton_kernels_dir, env=env)
+ kernels_whl_path = next(iter((triton_kernels_dir / "dist").glob("*.whl")))
+ shutil.copy(kernels_whl_path, Path.cwd())
+
return Path.cwd() / whl_path.name
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0b88247df27a5..991ea336a175b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -56,11 +56,11 @@ set(CMAKE_C_STANDARD
# ---[ Utils
include(cmake/public/utils.cmake)
-# --- [ Check that minimal gcc version is 9.3+
-if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.3)
+# --- [ Check that minimal gcc version is 9.2+
+if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.2)
message(
FATAL_ERROR
- "GCC-9.3 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
+ "GCC-9.2 or newer is required to compile PyTorch, but found ${CMAKE_CXX_COMPILER_VERSION}"
)
endif()
diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
index 267d1f5acea52..5b28cc6eccf01 100644
--- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
+++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
@@ -332,11 +332,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
const auto not_A_H = A.size(-2) >= A.size(-1);
Tensor Vcopy = V; // Shallow copy
-#ifdef USE_ROCM
+#ifdef ROCM_VERSION
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
// not guaranteed to be column major on ROCM, we have to create a copy to
// deal with this
- if (!not_A_H) {
+ if (compute_uv && !not_A_H) {
Vcopy = at::empty_like(V.mT(),
V.options()
.device(V.device())
@@ -351,8 +351,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
infos,
full_matrices, compute_uv, calculate_all_batches, batches);
});
-#ifdef USE_ROCM
- if (!not_A_H) {
+#ifdef ROCM_VERSION
+ if (compute_uv && !not_A_H) {
V.copy_(Vcopy);
}
#endif
@@ -526,8 +526,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
template
static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
const Tensor& infos, bool full_matrices, bool compute_uv) {
-#ifndef CUDART_VERSION
- TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
+#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
+ TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
#else
using value_t = typename c10::scalar_value_type::type;
int m = cuda_int_cast(A.size(-2), "m");
@@ -665,7 +665,7 @@ void svd_cusolver(const Tensor& A,
static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
// The default heuristic is to use gesvdj driver
-#ifdef USE_ROCM
+#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
const auto driver_v = std::string_view("gesvdj");
#else
const auto driver_v = driver.value_or("gesvdj");
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
index 99c38077611d6..af183038bb8e4 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
@@ -470,8 +470,8 @@ void gesvdjBatched>(
}
-// ROCM does not implement gesdva yet
-#ifdef CUDART_VERSION
+// ROCM does not implement gesdva correctly before 6.1
+#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
template<>
void gesvdaStridedBatched_buffersize(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,
diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
index 49bea10c65104..8402555a5c340 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
@@ -40,7 +40,27 @@
#include
+#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
+#define IS_CUSPARSE11_AVAILABLE() 1
+#else
+#define IS_CUSPARSE11_AVAILABLE() 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
+#define HIPSPARSE_FP16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_SUPPORT 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
+#define HIPSPARSE_FP16_BF16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_BF16_SUPPORT 0
+#endif
+
+#if IS_CUSPARSE11_AVAILABLE()
#include
+#endif
namespace at::native {
diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu
index e4c18102526ac..1b3ed4dc4ac42 100644
--- a/aten/src/ATen/test/cuda_vectorized_test.cu
+++ b/aten/src/ATen/test/cuda_vectorized_test.cu
@@ -32,23 +32,6 @@ void reset_buffers() {
}
}
-#if defined(USE_ROCM) && !defined(_WIN32)
-TEST(TestLoops, HasSameArgTypes) {
- // This is a compile-time unit test. If this file compiles without error,
- // then the test passes and during runtime, we just need to return.
- using namespace at::native::modern::detail;
- using func1_t = int (*)(float, float);
- using func2_t = int (*)(bool, float, float);
- using func3_t = int (*)(float);
- using func4_t = int (*)();
- static_assert(has_same_arg_types::value, "func1_t has the same argument types");
- static_assert(!has_same_arg_types::value, "func2_t does not have the same argument types");
- static_assert(has_same_arg_types::value, "func3_t has the same argument types");
- static_assert(has_same_arg_types::value, "func4_t has the same argument types");
- return;
-}
-#endif
-
TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
char *ptr = reinterpret_cast(buffer1);
diff --git a/related_commits b/related_commits
new file mode 100644
index 0000000000000..ee36e55601d0f
--- /dev/null
+++ b/related_commits
@@ -0,0 +1,10 @@
+ubuntu|pytorch|apex|master|2190fbaeb88384ed792373adbb83c182af117ca0|https://github.com/ROCm/apex
+centos|pytorch|apex|master|2190fbaeb88384ed792373adbb83c182af117ca0|https://github.com/ROCm/apex
+ubuntu|pytorch|torchvision|main|218d2ab791d437309f91e0486eb9fa7f00badc17|https://github.com/pytorch/vision
+centos|pytorch|torchvision|main|218d2ab791d437309f91e0486eb9fa7f00badc17|https://github.com/pytorch/vision
+ubuntu|pytorch|torchdata|main|92950795e0790eb74df995daf40b658e85fd2c9f|https://github.com/pytorch/data
+centos|pytorch|torchdata|main|92950795e0790eb74df995daf40b658e85fd2c9f|https://github.com/pytorch/data
+ubuntu|pytorch|torchaudio|main|3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2|https://github.com/pytorch/audio
+centos|pytorch|torchaudio|main|3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2|https://github.com/pytorch/audio
+ubuntu|pytorch|ao|main|3577306c8b32517afe8eb6eb7e84335601180598|https://github.com/pytorch/ao
+centos|pytorch|ao|main|3577306c8b32517afe8eb6eb7e84335601180598|https://github.com/pytorch/ao
diff --git a/requirements-build.txt b/requirements-build.txt
index 85923ae39cbdb..f2edf387fb97a 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -1,11 +1,12 @@
# Build System requirements
-setuptools>=70.1.0
-cmake>=3.27
-ninja
-numpy
-packaging
-pyyaml
-requests
-six # dependency chain: NNPACK -> PeachPy -> six
-typing-extensions>=4.10.0
pip # not technically needed, but this makes setup.py invocation work
+setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0
+cmake>=3.31.4
+ninja==1.11.1.3
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
+packaging==25.0
+pyyaml==6.0.2
+requests==2.32.4
+six==1.17.0 # dependency chain: NNPACK -> PeachPy -> six
+typing-extensions==4.14.1
diff --git a/requirements.txt b/requirements.txt
index fc4b53dfd49ea..090a733726658 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,15 +5,18 @@
# Install / Development extra requirements
build[uv] # for building sdist and wheel
-expecttest>=0.3.0
-filelock
-fsspec>=0.8.5
-hypothesis
-jinja2
-lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
-networkx>=2.5.1
-optree>=0.13.0
-psutil
-sympy>=1.13.3
-typing-extensions>=4.13.2
-wheel
+expecttest==0.3.0
+filelock==3.18.0
+fsspec==2025.7.0
+hypothesis==5.35.1
+jinja2==3.1.6
+lintrunner==0.12.7 ; platform_machine != "s390x"
+networkx==2.8.8
+ninja==1.11.1.3
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
+optree==0.13.0
+psutil==7.0.0
+sympy==1.13.3
+typing-extensions==4.14.1
+wheel==0.45.1
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 5150dab4b7cf1..eb6877b419d0b 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -69,6 +69,12 @@ def _op_supports_any_sparse(op):
) or (not IS_WINDOWS and not TEST_WITH_ROCM)
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
+HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
+
+SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
+SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)
def all_sparse_layouts(test_name='layout', include_strided=False):
return parametrize(test_name, [
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index f84adcc7bd262..e1bfd3f146991 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -25,7 +25,8 @@
all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
+ SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
import operator
if TEST_SCIPY:
@@ -1940,8 +1941,8 @@ def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
@dtypes(*floating_and_complex_types())
@dtypesIfCUDA(*floating_and_complex_types_and(
- *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
- *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
+ *[torch.half] if SPARSE_FLOAT16_SUPPORTED else [],
+ *[torch.bfloat16] if SPARSE_BFLOAT16_SUPPORTED else []))
@precisionOverride({torch.bfloat16: 3.5e-2, torch.float16: 1e-2})
def test_sparse_addmm(self, device, dtype):
def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 74dfe0c56c232..3f475bd6823b5 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -192,6 +192,9 @@ def tf32_off():
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
+ if torch.version.hip:
+ hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+ os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
@@ -200,6 +203,11 @@ def tf32_on(self, tf32_precision=1e-5):
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
+ if torch.version.hip:
+ if hip_allow_tf32 is not None:
+ os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+ else:
+ del os.environ["HIPBLASLT_ALLOW_TF32"]
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index 82547c8e28540..12e1a1209c2cd 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -8593,6 +8593,14 @@
"CUSPARSE_STATUS_ZERO_PIVOT",
("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPECIAL),
),
+ (
+ "CUSPARSE_STATUS_NOT_SUPPORTED",
+ ("HIPSPARSE_STATUS_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_SPECIAL),
+ ),
+ (
+ "CUSPARSE_STATUS_INSUFFICIENT_RESOURCES",
+ ("HIPSPARSE_STATUS_INSUFFICIENT_RESOURCES", CONV_NUMERIC_LITERAL, API_SPECIAL),
+ ),
(
"CUSPARSE_OPERATION_TRANSPOSE",
("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPECIAL),