@@ -339,6 +339,95 @@ def generate_ngram_sequences(data, seq_len_right, seq_len_left):
        sequence_dict[key] = left_seq + right_seq
        i += 1
    return sequence_dict


def validate_sequences(sequence_dict, seq_len_right, seq_len_left):
    micro_sequences = []
    macro_sequences = {}

    # short sequences with a low score are potential fragments (micro);
    # everything else is kept with its score (macro)
    for key in sequence_dict.keys():
        score = sequence_dict[key]

        if score < 1 and len(key.split()) <= seq_len_right:
            micro_sequences.append(key)
        else:
            macro_sequences[key] = score

    non_frag_sequences = []
    macro_sequences_copy = macro_sequences.copy()

    # macro sequences that contain a micro sequence are collected as
    # non-fragments and removed from the macro pool
    for sent in tqdm(micro_sequences, total=len(micro_sequences)):
        for key in macro_sequences.keys():
            if sent in key:
                non_frag_sequences.append(key)
                # pop with a default instead of del: a key containing more
                # than one micro sequence would otherwise raise KeyError
                macro_sequences_copy.pop(key, None)

    macro_sequences = macro_sequences_copy.copy()

    # non-fragment and micro sequences are re-added with score 0
    for sent in non_frag_sequences:
        macro_sequences[sent] = 0

    for sent in micro_sequences:
        macro_sequences[sent] = 0

    return macro_sequences

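# A minimal usage sketch for the two helpers above (illustrative only: the
# sample sentence is made up, and tqdm is assumed to be imported at module top).
#
#   seqs = generate_ngram_sequences(data=["how do i reset my password"],
#                                   seq_len_right=3, seq_len_left=2)
#   frags = validate_sequences(seqs, seq_len_right=3, seq_len_left=2)
#   # frags maps candidate sequences to scores; micro and non-fragment
#   # sequences carry score 0
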
def create_fragment_detection_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

    """
    This function transforms data for the fragment detection task (detecting whether a sentence is an incomplete fragment or a complete sentence).
    It takes data in the single-sentence classification format and creates fragment samples from the sentences.
    In the transformed file, labels 1 and 0 represent fragment and non-fragment sentences respectively.
    The following transformed file is written at wrtDir

    - Fragment transformed tsv file containing fragment/non-fragment sentences and labels


    For using this transform function, set ``transform_func`` : **create_fragment_detection_tsv** in the transform file.
    Args:
        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
        wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value

            - ``data_frac`` (defaults to 0.2) : Fraction of data to consider for making fragments.
            - ``seq_len_right`` (defaults to 3) : Right window length for making n-grams.
            - ``seq_len_left`` (defaults to 2) : Left window length for making n-grams.
            - ``sep`` (defaults to "\\t") : Column separator for the input file.
            - ``query_col`` (defaults to 2) : Column number containing the sentences. Counting starts from 0.

    """

    transParamDict.setdefault("data_frac", 0.2)
    transParamDict.setdefault("seq_len_right", 3)
    transParamDict.setdefault("seq_len_left", 2)
    transParamDict.setdefault("sep", "\t")
    transParamDict.setdefault("query_col", 2)

    allDataDf = pd.read_csv(os.path.join(dataDir, readFile), sep=transParamDict["sep"], header=None)
    sampledDataDf = allDataDf.sample(frac=float(transParamDict['data_frac']), random_state=42)

    # the query_col column (default 2) holds the sentences; column 0 is uid, column 1 is label.
    # make n-grams with the left and right windows
    seqDict = generate_ngram_sequences(data=list(sampledDataDf.iloc[:, int(transParamDict["query_col"])]),
                                       seq_len_right=transParamDict['seq_len_right'],
                                       seq_len_left=transParamDict['seq_len_left'])

    fragDict = validate_sequences(seqDict, seq_len_right=transParamDict['seq_len_right'],
                                  seq_len_left=transParamDict['seq_len_left'])

    # fragments (label 1) come from the n-gram dictionary; non-fragments (label 0) are the original sentences
    finalDf = pd.DataFrame({'uid': [i for i in range(len(fragDict) + len(allDataDf))],
                            'label': [1]*len(fragDict) + [0]*len(allDataDf),
                            'query': list(fragDict.keys()) + list(allDataDf.iloc[:, int(transParamDict["query_col"])])})

    print('number of fragment samples : ', len(fragDict))
    print('number of non-fragment samples : ', len(allDataDf))

    # saving
    print('writing fragment file for {} at {}'.format(readFile, wrtDir))

    finalDf.to_csv(os.path.join(wrtDir, 'fragment_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
                   index=False, header=False)

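# A hypothetical invocation (directory and file names are made up; the
# transParamDict values shown are simply the defaults set above):
#
#   create_fragment_detection_tsv(dataDir='data', readFile='queries.tsv',
#                                 wrtDir='out',
#                                 transParamDict={'data_frac': 0.2,
#                                                 'seq_len_right': 3,
#                                                 'seq_len_left': 2,
#                                                 'sep': '\t',
#                                                 'query_col': 2})
#   # -> writes out/fragment_queries.tsv with label 1 (fragment) and
#   #    label 0 (non-fragment) rows
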
def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
    """
    This function transforms the MSMARCO triples data available at `triples <https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz>`_
@@ -412,7 +501,7 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam

    devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
    print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))


def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

    """
@@ -458,6 +547,7 @@ def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrain
    labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.lower().replace('.json', '')))
    joblib.dump(labelMap, labelMapPath)
    print('Created label map file at', labelMapPath)


def imdb_sentiment_data_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
