Skip to content

Commit 7711ffe

Browse files
dlfw related changes (#3814)
1 parent ad0a7b1 commit 7711ffe

File tree

4 files changed

+59
-100
lines changed

4 files changed

+59
-100
lines changed

setup.py

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,26 @@
3131
__tensorrt_rtx_version__: str = "0.0"
3232

3333
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
34+
# CI_PIPELINE_ID is the environment variable set by DLFW ci build
35+
IS_DLFW_CI = os.environ.get("CI_PIPELINE_ID") is not None
3436

3537

3638
def get_root_dir() -> Path:
37-
return Path(
38-
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
39-
.decode("ascii")
40-
.strip()
41-
)
39+
return Path(__file__).parent.absolute()
4240

4341

4442
def get_git_revision_short_hash() -> str:
45-
return (
46-
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
47-
.decode("ascii")
48-
.strip()
49-
)
43+
# DLFW ci build does not have git
44+
try:
45+
return (
46+
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
47+
.decode("ascii")
48+
.strip()
49+
)
50+
except:
51+
print("WARNING: Could not get git revision short hash, using default one")
52+
# in release/ngc/25.10 branch this is the commit hash of the pytorch commit that is used for dlfw package
53+
return "0000000"
5054

5155

5256
def get_base_version() -> str:
@@ -718,58 +722,57 @@ def run(self):
718722
with open(os.path.join(get_root_dir(), "README.md"), "r", encoding="utf-8") as fh:
719723
long_description = fh.read()
720724

725+
base_requirements = [
726+
"packaging>=23",
727+
"typing-extensions>=4.7.0",
728+
"dllist",
729+
]
721730

722-
def get_requirements():
723-
requirements = [
724-
"packaging>=23",
725-
"typing-extensions>=4.7.0",
726-
"dllist",
727-
]
728731

732+
def get_requirements():
729733
if IS_JETPACK:
730-
requirements.extend(
731-
[
732-
"torch>=2.8.0,<2.9.0",
733-
"tensorrt>=10.3.0,<10.4.0",
734-
"numpy<2.0.0",
735-
]
736-
)
734+
requirements = get_jetpack_requirements()
737735
elif IS_SBSA:
738-
requirements.extend(
739-
[
740-
"torch>=2.9.0.dev,<2.10.0",
741-
"tensorrt>=10.12.0,<10.13.0",
742-
"tensorrt-cu12>=10.12.0,<10.13.0",
743-
"tensorrt-cu12-bindings>=10.12.0,<10.13.0",
744-
"tensorrt-cu12-libs>=10.12.0,<10.13.0",
745-
"numpy",
746-
]
747-
)
736+
requirements = get_sbsa_requirements()
748737
else:
749-
requirements.extend(
750-
[
751-
"torch>=2.9.0.dev,<2.10.0",
752-
"numpy",
753-
]
754-
)
755-
if USE_TRT_RTX:
756-
requirements.extend(
757-
[
758-
"tensorrt-rtx>=1.0.0.21",
738+
# standard linux and windows requirements
739+
requirements = base_requirements + ["numpy"]
740+
if not IS_DLFW_CI:
741+
requirements = requirements + ["torch>=2.9.0.dev,<2.10.0"]
742+
if USE_TRT_RTX:
743+
requirements = requirements + [
744+
"tensorrt_rtx>=1.0.0.21",
759745
]
760-
)
761-
else:
762-
requirements.extend(
763-
[
746+
else:
747+
requirements = requirements + [
764748
"tensorrt>=10.12.0,<10.13.0",
765749
"tensorrt-cu12>=10.12.0,<10.13.0",
766750
"tensorrt-cu12-bindings>=10.12.0,<10.13.0",
767751
"tensorrt-cu12-libs>=10.12.0,<10.13.0",
768752
]
769-
)
770753
return requirements
771754

772755

756+
def get_jetpack_requirements():
757+
jetpack_requirements = base_requirements + ["numpy<2.0.0"]
758+
if IS_DLFW_CI:
759+
return jetpack_requirements
760+
return jetpack_requirements + ["torch>=2.8.0,<2.9.0", "tensorrt>=10.3.0,<10.4.0"]
761+
762+
763+
def get_sbsa_requirements():
764+
sbsa_requirements = base_requirements + ["numpy"]
765+
if IS_DLFW_CI:
766+
return sbsa_requirements
767+
return sbsa_requirements + [
768+
"torch>=2.9.0.dev,<2.10.0",
769+
"tensorrt>=10.12.0,<10.13.0",
770+
"tensorrt-cu12>=10.12.0,<10.13.0",
771+
"tensorrt-cu12-bindings>=10.12.0,<10.13.0",
772+
"tensorrt-cu12-libs>=10.12.0,<10.13.0",
773+
]
774+
775+
773776
setup(
774777
name="torch_tensorrt",
775778
ext_modules=ext_modules,

tests/modules/hub.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,53 +39,6 @@
3939
# "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
4040
}
4141

42-
if importlib.util.find_spec("torchvision"):
43-
import timm
44-
import torchvision.models as models
45-
46-
torchvision_models = {
47-
"alexnet": {"model": models.alexnet(pretrained=True), "path": "both"},
48-
"vgg16": {"model": models.vgg16(pretrained=True), "path": "both"},
49-
"squeezenet": {"model": models.squeezenet1_0(pretrained=True), "path": "both"},
50-
"densenet": {"model": models.densenet161(pretrained=True), "path": "both"},
51-
"inception_v3": {"model": models.inception_v3(pretrained=True), "path": "both"},
52-
"shufflenet": {
53-
"model": models.shufflenet_v2_x1_0(pretrained=True),
54-
"path": "both",
55-
},
56-
"mobilenet_v2": {"model": models.mobilenet_v2(pretrained=True), "path": "both"},
57-
"resnext50_32x4d": {
58-
"model": models.resnext50_32x4d(pretrained=True),
59-
"path": "both",
60-
},
61-
"wideresnet50_2": {
62-
"model": models.wide_resnet50_2(pretrained=True),
63-
"path": "both",
64-
},
65-
"mnasnet": {"model": models.mnasnet1_0(pretrained=True), "path": "both"},
66-
"resnet18": {
67-
"model": torch.hub.load(
68-
"pytorch/vision:v0.9.0", "resnet18", pretrained=True
69-
),
70-
"path": "both",
71-
},
72-
"resnet50": {
73-
"model": torch.hub.load(
74-
"pytorch/vision:v0.9.0", "resnet50", pretrained=True
75-
),
76-
"path": "both",
77-
},
78-
"efficientnet_b0": {
79-
"model": timm.create_model("efficientnet_b0", pretrained=True),
80-
"path": "script",
81-
},
82-
"vit": {
83-
"model": timm.create_model("vit_base_patch16_224", pretrained=True),
84-
"path": "script",
85-
},
86-
}
87-
to_test_models.update(torchvision_models)
88-
8942

9043
def get(n, m, manifest):
9144
print("Downloading {}".format(n))

tests/py/dynamo/models/test_models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
assertions = unittest.TestCase()
1515

1616
if importlib.util.find_spec("torchvision"):
17-
import timm
1817
import torchvision.models as models
18+
if importlib.util.find_spec("timm"):
19+
import timm
1920

2021

2122
@pytest.mark.unit
@@ -132,11 +133,11 @@ def test_resnet18_torch_exec_ops(ir):
132133

133134

134135
@pytest.mark.unit
136+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
135137
@unittest.skipIf(
136138
not importlib.util.find_spec("torchvision"),
137139
"torchvision is not installed",
138140
)
139-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
140141
def test_mobilenet_v2(ir, dtype):
141142
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
142143
pytest.skip("TensorRT-RTX does not support bfloat16")
@@ -174,11 +175,11 @@ def test_mobilenet_v2(ir, dtype):
174175

175176

176177
@pytest.mark.unit
178+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
177179
@unittest.skipIf(
178180
not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
179181
"timm or torchvision not installed",
180182
)
181-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
182183
def test_efficientnet_b0(ir, dtype):
183184
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
184185
pytest.skip("TensorRT-RTX does not support bfloat16")
@@ -221,11 +222,11 @@ def test_efficientnet_b0(ir, dtype):
221222

222223

223224
@pytest.mark.unit
225+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
224226
@unittest.skipIf(
225227
not importlib.util.find_spec("transformers"),
226228
"transformers is required to run this test",
227229
)
228-
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
229230
def test_bert_base_uncased(ir, dtype):
230231
if torchtrt.ENABLED_FEATURES.tensorrt_rtx and dtype == torch.bfloat16:
231232
pytest.skip("TensorRT-RTX does not support bfloat16")

tests/py/dynamo/models/test_models_export.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
from packaging.version import Version
1313

1414
if importlib.util.find_spec("torchvision"):
15-
import timm
1615
import torchvision.models as models
1716

17+
if importlib.util.find_spec("timm"):
18+
import timm
19+
1820
assertions = unittest.TestCase()
1921

2022

0 commit comments

Comments
 (0)