Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tf/tfprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,15 @@ def __init__(self, cfg):
self.renorm_max_d = self.cfg['training'].get('renorm_max_d', 0)
self.renorm_momentum = self.cfg['training'].get(
'renorm_momentum', 0.99)

if self.cfg['gpu'] == 'all':
self.strategy = tf.distribute.MirroredStrategy()
tf.distribute.experimental_set_strategy(self.strategy)
elif self.cfg['gpu'].__contains__(','):
gpus = tf.config.experimental.list_physical_devices('GPU')
for i in self.cfg['gpu'].split(","):
tf.config.experimental.set_visible_devices(gpus[int(i)],
'GPU')
else:
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
Expand Down