Skip to content

Commit 4e5b5b7

Browse files
author
Ruo-Ping Dong
authored
Add global lock for torch.onnx.export() (#4665)
* Cherry-pick fix from #4659
1 parent 60ae629 commit 4e5b5b7

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

ml-agents/mlagents/trainers/torch/model_serialization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@ class exporting_to_onnx:
1919
This implementation is thread safe.
2020
"""
2121

22+
# local is_exporting flag for each thread
2223
_local_data = threading.local()
2324
_local_data._is_exporting = False
2425

26+
# global lock shared among all threads, to make sure only one thread is exporting at a time
27+
_lock = threading.Lock()
28+
2529
def __enter__(self):
30+
self._lock.acquire()
2631
self._local_data._is_exporting = True
2732

2833
def __exit__(self, *args):
2934
self._local_data._is_exporting = False
35+
self._lock.release()
3036

3137
@staticmethod
3238
def is_exporting():

0 commit comments

Comments
 (0)