Skip to content

Commit db91fc4

Browse files
committed
bool arguments fix
1 parent 2a6db66 commit db91fc4

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

utils/tranform_functions.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,34 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
671671

672672
devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
673673
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
674-
674+
675+
def query_correctness_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

    """
    - Query correctness transformed file

    Reads a headerless, tab-separated file whose two columns are the query text and a
    numeric score, binarizes the score and writes a headerless tsv with the columns
    ``uid``, ``label``, ``query`` (in this order) to ``wrtDir``. The output file is named
    ``query_correctness_<readFile>``.

    For using this transform function, set ``transform_func`` : **query_correctness_to_tsv** in transform file.

    Args:
        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
        wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
        transParamDict (:obj:`dict`) : Dictionary of function specific parameters. Not used by this transformation function.
        isTrainFile (:obj:`bool`, defaults to :obj:`False`) : Not used by this transformation function.
    """
    print('Making data from file {}'.format(readFile))
    df = pd.read_csv(os.path.join(dataDir, readFile), sep='\t', header=None, names=['query', 'label'])

    # we consider anything above 0.6 as structured query (3 or more annotations as structured),
    # and others as non-structured. The comparison is done column-wise; a missing score
    # compares as False and therefore gets label 0.
    labels = (df['label'] >= 0.6).astype(int).astype(str)

    # uid is the row index of the read file, stored as a string.
    wrtDf = pd.DataFrame(
        {
            'uid': [str(i) for i in df.index],
            'label': labels,
            'query': df['query'],
        },
        columns=['uid', 'label', 'query'],
    )

    # writing
    wrtPath = os.path.join(wrtDir, 'query_correctness_{}'.format(readFile))
    wrtDf.to_csv(wrtPath, sep="\t", index=False, header=False)
    print('File saved at: ', wrtPath)
701+
675702
def clinc_out_of_scope_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
676703

677704
"""

0 commit comments

Comments
 (0)