Skip to content

Commit 3c07346

Browse files
committed
Fix file-name splitting: derive the .json name by lowercasing and removing '.tsv' instead of split('.')[0]
1 parent 4c8d04f commit 3c07346

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

data_preparation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ def main():
303303
print('Loading raw data for task {} from {}'.format(taskName, os.path.join(args.data_dir, file)))
304304
rows = load_data(os.path.join(args.data_dir, file), tasks.taskTypeMap[taskName],
305305
hasLabels = args.has_labels)
306-
wrtFile = os.path.join(dataPath, '{}.json'.format(file.split('.')[0]))
306+
#wrtFile = os.path.join(dataPath, '{}.json'.format(file.split('.')[0]))
307+
wrtFile = os.path.join(dataPath, '{}.json'.format(file.lower().replace('.tsv', '')))
307308
print('Processing Started...')
308309
create_data_multithreaded(rows, wrtFile, tokenizer, tasks, taskName,
309310
args.max_seq_len, args.multithreaded)

examples/intent_ner_fragment/tasks_file_snips.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ fragdetect:
3939
loss_type: CrossEntropyLoss
4040
task_type: SingleSenClassification
4141
file_names:
42-
- fragment_snips_train.tsv
43-
- fragment_snips_dev.tsv
44-
- fragment_snips_test.tsv
42+
- fragment_intent_snips_train.tsv
43+
- fragment_intent_snips_dev.tsv
44+
- fragment_intent_snips_test.tsv

train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def make_data_handlers(taskParams, mode, isTrain, gpu):
119119
taskType = taskParams.taskTypeMap[taskName]
120120
if mode == "test":
121121
assert len(taskParams.fileNamesMap[taskName])==3, "test file is required along with train, dev"
122-
dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].split('.')[0])
122+
#dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].split('.')[0])
123+
dataFileName = '{}.json'.format(taskParams.fileNamesMap[taskName][modeIdx].lower().replace('.tsv',''))
123124
taskDataPath = os.path.join(args.data_dir, dataFileName)
124125
assert os.path.exists(taskDataPath), "{} doesn't exist".format(taskDataPath)
125126
taskDict = {"data_task_id" : int(taskId),

0 commit comments

Comments
 (0)