Skip to content

Commit e33da24

Browse files
authored
fix bug of SeqAugment layer & and compatibility problem of Windows Operation System (#520)
* fix bug of SeqAugment layer * add compacity for windows users
1 parent bc38227 commit e33da24

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

easy_rec/python/layers/keras/data_augment.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,28 @@ def mask_fn():
7575
def reorder_fn():
7676
return item_reorder(seq, length, aug_param.reorder_rate)
7777

78-
method = tf.random.uniform([], minval=0, maxval=3, dtype=tf.int32)
78+
trans_fn = []
79+
if aug_param.crop_rate < 1.0:
80+
trans_fn.append(crop_fn)
81+
if aug_param.mask_rate > 0:
82+
trans_fn.append(mask_fn)
83+
if aug_param.reorder_rate > 0:
84+
trans_fn.append(reorder_fn)
85+
86+
num_trans = len(trans_fn)
87+
if num_trans == 0:
88+
return seq, length
89+
90+
if num_trans == 1:
91+
return trans_fn[0]()
92+
93+
method = tf.random.uniform([], minval=0, maxval=num_trans, dtype=tf.int32)
94+
if num_trans == 2:
95+
return tf.cond(tf.equal(method, 0), trans_fn[0], trans_fn[1])
7996

8097
aug_seq, aug_len = tf.cond(
8198
tf.equal(method, 0), crop_fn,
8299
lambda: tf.cond(tf.equal(method, 1), mask_fn, reorder_fn))
83-
84100
return aug_seq, aug_len
85101

86102

git-lfs/git_lfs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def get_yes_no(msg):
224224
'usage: python git_lfs.py [pull] [push] [add filename] [resolve_conflict]'
225225
)
226226
sys.exit(1)
227+
home_directory = os.path.expanduser("~")
227228
with open('.git_oss_config_pub', 'r') as fin:
228229
git_oss_data_dir = None
229230
host = None
@@ -237,7 +238,7 @@ def get_yes_no(msg):
237238
continue
238239
if line_str.startswith('#'):
239240
continue
240-
line_str = line_str.replace('~/', os.environ['HOME'] + '/')
241+
line_str = line_str.replace('~/', home_directory + '/')
241242
line_str = line_str.replace('${TMPDIR}/',
242243
os.environ.get('TMPDIR', '/tmp/'))
243244
line_str = line_str.replace('${PROJECT_NAME}', get_proj_name())
@@ -251,7 +252,7 @@ def get_yes_no(msg):
251252
elif line_tok[0] == 'git_oss_private_config':
252253
git_oss_private_path = line_tok[1]
253254
if git_oss_private_path.startswith('~/'):
254-
git_oss_private_path = os.path.join(os.environ['HOME'],
255+
git_oss_private_path = os.path.join(home_directory,
255256
git_oss_private_path[2:])
256257
elif line_tok[0] == 'git_oss_cache_dir':
257258
git_oss_cache_dir = line_tok[1]

0 commit comments

Comments
 (0)