@@ -671,7 +671,34 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
671671
672672 devDf .to_csv (os .path .join (wrtDir , 'msmarco_answerability_test.tsv' ), sep = '\t ' , index = False , header = False )
673673 print ('Test file written at: ' , os .path .join (wrtDir , 'msmarco_answerability_test.tsv' ))
674-
674+
def query_correctness_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

    """
    - Query correctness transformed file
    For using this transform function, set ``transform_func`` : **query_correctness_to_tsv** in transform file.

    Reads a headerless, tab-separated file with two columns (query, label score), binarizes the
    score and writes a headerless, tab-separated file with three columns (uid, label, query)
    named ``query_correctness_<readFile>`` into ``wrtDir``.

    Args:
        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
        wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
        transParamDict (:obj:`dict`): Dictionary of function specific parameters. Not used by this transformation function.
        isTrainFile (:obj:`bool`, defaults to :obj:`False`): Not used by this transformation function.
    """
    print('Making data from file {}'.format(readFile))
    df = pd.read_csv(os.path.join(dataDir, readFile), sep='\t', header=None, names=['query', 'label'])

    # Scores of 0.6 and above are treated as structured queries (3 or more annotations
    # as structured) and mapped to 1; everything else is mapped to 0.
    labels = (df['label'] >= 0.6).astype(int)

    # The uid is the row's index in the file that was read, written as a string.
    wrtDf = pd.DataFrame(
        {
            'uid': [str(idx) for idx in df.index],
            'label': labels,
            'query': df['query'],
        },
        columns=['uid', 'label', 'query'],
    )

    # writing (the delimiter must be a single character for ``to_csv``)
    wrtPath = os.path.join(wrtDir, 'query_correctness_{}'.format(readFile))
    wrtDf.to_csv(wrtPath, sep='\t', index=False, header=False)
    print('File saved at: ', wrtPath)
701+
675702def clinc_out_of_scope_to_tsv (dataDir , readFile , wrtDir , transParamDict , isTrainFile = False ):
676703
677704 """
0 commit comments