Skip to content

Commit 768bd68

Browse files
Merge pull request NVIDIA#268 from szmigacz/inprocess_no_pickle
Inprocess: states and keys are serialized with json, not pickle
2 parents b1b1003 + c47d7b3 commit 768bd68

File tree

1 file changed

+15
-23
lines changed
  • src/nvidia_resiliency_ext/inprocess

1 file changed

+15
-23
lines changed

src/nvidia_resiliency_ext/inprocess/store.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,13 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import dataclasses
1718
import datetime
1819
import functools
1920
import inspect
21+
import json
2022
import logging
2123
import os
22-
23-
# Issue: [B403:blacklist] Consider possible security implications associated with pickle module.
24-
# Severity: Low Confidence: High
25-
# CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
26-
# More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_imports.html#b403-import-pickle
27-
import pickle # nosec
2824
import sys
2925
import time
3026
from collections.abc import Iterable
@@ -36,7 +32,7 @@
3632

3733
from . import exception, utils
3834
from .attribution import InterruptionRecord
39-
from .state import Mode
35+
from .state import Mode, State
4036

4137

4238
class BarrierError(exception.RestartError):
@@ -109,30 +105,26 @@ def send_heartbeat(self, rank: int):
109105
self.set(self.HEARTBEAT.format(rank=rank), str(time.time()))
110106

111107
def send_state(self, state, rank: int):
112-
self.set(self.STATE.format(rank=rank), pickle.dumps(state))
108+
state_dict = dataclasses.asdict(state)
109+
state_dict['mode'] = state.mode.name
110+
state_dict['fn_exception'] = None
111+
self.set(self.STATE.format(rank=rank), json.dumps(state_dict))
113112

114113
def send_key(self, key, rank: int):
115-
self.set(self.KEY.format(rank=rank), pickle.dumps(key))
114+
self.set(self.KEY.format(rank=rank), json.dumps(key))
116115

117116
def get_states(self, ranks):
118-
states = [
119-
# Issue: [B301:blacklist] Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue.
120-
# Severity: Medium Confidence: High
121-
# CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
122-
# More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_calls.html#b301-pickle
123-
pickle.loads(state) # nosec
124-
for state in self.multi_get([self.STATE.format(rank=rank) for rank in ranks])
125-
]
117+
states = []
118+
for data in self.multi_get([self.STATE.format(rank=rank) for rank in ranks]):
119+
state_dict = json.loads(data)
120+
state_dict['mode'] = Mode[state_dict['mode']]
121+
states.append(State(**state_dict))
126122
return states
127123

128124
def get_keys(self, ranks):
129125
keys = [
130-
# Issue: [B301:blacklist] Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue.
131-
# Severity: Medium Confidence: High
132-
# CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
133-
# More Info: https://bandit.readthedocs.io/en/1.8.3/blacklists/blacklist_calls.html#b301-pickle
134-
pickle.loads(key) # nosec
135-
for key in self.multi_get([self.KEY.format(rank=rank) for rank in ranks])
126+
json.loads(data)
127+
for data in self.multi_get([self.KEY.format(rank=rank) for rank in ranks])
136128
]
137129
return keys
138130

0 commit comments

Comments
 (0)