File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -1459,14 +1459,15 @@ def make_rotary_embedding_multi_cache(self, **kwargs):
1459
1459
self .rope_attrs ["save_caches" ] = False
1460
1460
cos_cache_small , sin_cache_small = self .make_rotary_embedding_caches (cos_cache_name = cos_cache_small_name , sin_cache_name = sin_cache_small_name )
1461
1461
1462
- if self .ep == "dml" :
1463
- # Concat small and large cos/sin caches for DML EP only
1462
+ if self .ep in ["dml" , "NvTensorRtRtx" ]:
1463
+ # Concat small and large cos/sin caches for DML and NvTensorRtRtx EPs
1464
+ # These EPs don't support the If operator
1464
1465
cos_cache = torch .cat ((cos_cache_small , cos_cache_large ), dim = 0 )
1465
1466
sin_cache = torch .cat ((sin_cache_small , sin_cache_large ), dim = 0 )
1466
1467
# Save cos/sin caches to disk
1467
1468
self .make_initializer (cos_cache , cos_cache_name )
1468
1469
self .make_initializer (sin_cache , sin_cache_name )
1469
- # Do NOT make the subgraph with the If node for DML EP .
1470
+ # Do NOT make the subgraph with the If node for these EPs .
1470
1471
return
1471
1472
1472
1473
# Make the following subgraph to decide which cos/sin caches to use in the rotary embeddings
You can’t perform that action at this time.
0 commit comments