Skip to content

Commit 1103a0f

Browse files
authored
Fixes cuda version as float for AutoMate to correctly convert patch versions (#3795)
# Description To convert cuda version from a string to a float, I update the function to handle cases with multiple points, e.g. string '12.8.9' will be converted to float 12.89. Before, float('12.8.9') will return None for failure conversion. ## Type of change - Bug fix (non-breaking change which fixes an issue) ## Checklist - [x] I have read and understood the [contribution guidelines](https://isaac-sim.github.io/IsaacLab/main/source/refs/contributing.html) - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [ ] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [ ] I have added my name to the `CONTRIBUTORS.md` or my name already exists there
1 parent b70bd42 commit 1103a0f

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, cfg: AssemblyEnvCfg, render_mode: str | None = None, **kwargs
6161

6262
# Create criterion for dynamic time warping (later used for imitation reward)
6363
cuda_version = automate_algo.get_cuda_version()
64-
if (cuda_version is not None) and (cuda_version < 13.0):
64+
if (cuda_version is not None) and (cuda_version < (13, 0, 0)):
6565
self.soft_dtw_criterion = SoftDTW(use_cuda=True, device=self.device, gamma=self.cfg_task.soft_dtw_gamma)
6666
else:
6767
self.soft_dtw_criterion = SoftDTW(use_cuda=False, device=self.device, gamma=self.cfg_task.soft_dtw_gamma)

source/isaaclab_tasks/isaaclab_tasks/direct/automate/automate_algo_utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@
2525
"""
2626

2727

28+
def parse_cuda_version(version_string):
29+
"""
30+
Parse CUDA version string into comparable tuple of (major, minor, patch).
31+
32+
Args:
33+
version_string: Version string like "12.8.9" or "11.2"
34+
35+
Returns:
36+
Tuple of (major, minor, patch) as integers, where patch defaults to 0 iff
37+
not present.
38+
39+
Example:
40+
"12.8.9" -> (12, 8, 9)
41+
"11.2" -> (11, 2, 0)
42+
"""
43+
parts = version_string.split(".")
44+
major = int(parts[0])
45+
minor = int(parts[1]) if len(parts) > 1 else 0
46+
patch = int(parts[2]) if len(parts) > 2 else 0
47+
return (major, minor, patch)
48+
49+
2850
def get_cuda_version():
2951
try:
3052
# Execute nvcc --version command
@@ -34,7 +56,7 @@ def get_cuda_version():
3456
# Use regex to find the CUDA version (e.g., V11.2.67)
3557
match = re.search(r"V(\d+\.\d+(\.\d+)?)", output)
3658
if match:
37-
return float(match.group(1))
59+
return parse_cuda_version(match.group(1))
3860
else:
3961
print("CUDA version not found in output.")
4062
return None

0 commit comments

Comments
 (0)