
Commit 4f608c5

transform function fix
1 parent c4c1df3 commit 4f608c5

File tree

1 file changed (+0, −159 lines)

utils/tranform_functions.py

Lines changed: 0 additions & 159 deletions
@@ -340,93 +340,6 @@ def generate_ngram_sequences(data, seq_len_right, seq_len_left):
         i += 1
     return sequence_dict
 
-def validate_sequences(sequence_dict, seq_len_right, seq_len_left):
-    micro_sequences = []
-    macro_sequences = {}
-
-    for key in sequence_dict.keys():
-        score = sequence_dict[key]
-
-        if score < 1 and len(key.split()) <= seq_len_right:
-            micro_sequences.append(key)
-        else:
-            macro_sequences[key] = score
-
-    non_frag_sequences = []
-    macro_sequences_copy = macro_sequences.copy()
-
-    for sent in tqdm(micro_sequences, total=len(micro_sequences)):
-        for key in macro_sequences.keys():
-            if sent in key:
-                non_frag_sequences.append(key)
-                del macro_sequences_copy[key]
-
-        macro_sequences = macro_sequences_copy.copy()
-
-    for sent in non_frag_sequences:
-        macro_sequences[sent] = 0
-
-    for sent in micro_sequences:
-        macro_sequences[sent] = 0
-
-    return macro_sequences
-
-def create_fragment_detection_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
-
-    """
-    This function transforms data for the fragment detection task (detecting whether a sentence is an incomplete fragment or not).
-    It takes data in single sentence classification format and creates fragment samples from the sentences.
-    In the transformed file, labels 1 and 0 represent fragment and non-fragment sentences respectively.
-    The following transformed file is written at wrtDir
-
-    - Fragment transformed tsv file containing fragment/non-fragment sentences and labels
-
-    To use this transform function, set ``transform_func`` : **create_fragment_detection_tsv** in the transform file.
-    Args:
-        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
-        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
-        wrtDir (:obj:`str`) : Path to the directory where the transformed tsv files are saved.
-        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value
-
-            - ``data_frac`` (defaults to 0.2) : Fraction of data to consider for making fragments.
-            - ``seq_len_right`` (defaults to 3) : Right window length for making n-grams.
-            - ``seq_len_left`` (defaults to 2) : Left window length for making n-grams.
-            - ``sep`` (defaults to "\t") : Column separator for the input file.
-            - ``query_col`` (defaults to 2) : Column number containing sentences. Counting starts from 0.
-
-    """
-
-    transParamDict.setdefault("data_frac", 0.2)
-    transParamDict.setdefault("seq_len_right", 3)
-    transParamDict.setdefault("seq_len_left", 2)
-    transParamDict.setdefault("sep", "\t")
-    transParamDict.setdefault("query_col", 2)
-
-    allDataDf = pd.read_csv(os.path.join(dataDir, readFile), sep=transParamDict["sep"], header=None)
-    sampledDataDf = allDataDf.sample(frac=float(transParamDict['data_frac']), random_state=42)
-
-    # the query_col column (default 2) is considered to hold queries; 0th is uid, 1st is label
-    # making n-grams with left and right windows
-    seqDict = generate_ngram_sequences(data=list(sampledDataDf.iloc[:, int(transParamDict["query_col"])]),
-                                       seq_len_right=transParamDict['seq_len_right'],
-                                       seq_len_left=transParamDict['seq_len_left'])
-
-    fragDict = validate_sequences(seqDict, seq_len_right=transParamDict['seq_len_right'],
-                                  seq_len_left=transParamDict['seq_len_left'])
-
-    finalDf = pd.DataFrame({'uid': [i for i in range(len(fragDict) + len(allDataDf))],
-                            'label': [1]*len(fragDict) + [0]*len(allDataDf),
-                            'query': list(fragDict.keys()) + list(allDataDf.iloc[:, int(transParamDict["query_col"])])})
-
-    print('number of fragment samples : ', len(fragDict))
-    print('number of non-fragment samples : ', len(allDataDf))
-    # saving
-    print('writing fragment file for {} at {}'.format(readFile, wrtDir))
-
-    finalDf.to_csv(os.path.join(wrtDir, 'fragment_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
-                   index=False, header=False)
-
 def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
 
     """
@@ -573,79 +486,7 @@ def qqp_query_similarity_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTra
                   index=False, header=False)
     print('Test file saved at: {}'.format(os.path.join(wrtDir, 'qqp_query_similarity_test.tsv')))
 
-def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
-    """
-    This function transforms the MSMARCO triples data available at `triples <https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz>`_
-
-    The data contains triplets where the first entry is the query, the second is a context passage from which the query can be
-    answered (positive passage), while the third entry is a context passage from which the query cannot be answered (negative passage).
-    Data is transformed into sentence pair classification format, with the query-positive context pair labeled as 1 (answerable)
-    and the query-negative context pair labeled as 0 (non-answerable).
-
-    The following transformed files are written at wrtDir
-
-    - Sentence pair transformed downsampled file.
-    - Sentence pair transformed train tsv file for the answerability task
-    - Sentence pair transformed dev tsv file for the answerability task
-    - Sentence pair transformed test tsv file for the answerability task
-
-    To use this transform function, set ``transform_func`` : **msmarco_answerability_detection_to_tsv** in the transform file.
 
-    Args:
-        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
-        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
-        wrtDir (:obj:`str`) : Path to the directory where the transformed tsv files are saved.
-        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary of function-specific parameters. Optional here, as the only key has a default.
-
-            - ``data_frac`` (defaults to 0.01) : Fraction of data to keep while downsampling, as the original data size is very large.
-    """
-    transParamDict.setdefault("data_frac", 0.01)
-    sampleEvery = int(1/float(transParamDict["data_frac"]))
-    startId = 0
-    print('Making data from file {} ....'.format(readFile))
-    rf = open(os.path.join(dataDir, readFile))
-    sf = open(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv'), 'w')
-
-    # reading the big file line by line
-    for i, row in enumerate(rf):
-        # progress logging
-        if i % 100000 == 0:
-            print("Processing {} rows...".format(i))
-
-        # sampling: keep every sampleEvery-th triple
-        if i % sampleEvery == 0:
-            rowData = row.split('\t')
-            posRowData = str(startId) + '\t' + str(1) + '\t' + rowData[0] + '\t' + rowData[1]
-            negRowData = str(startId+1) + '\t' + str(0) + '\t' + rowData[0] + '\t' + rowData[2].rstrip('\n')
-
-            # an important point here is to strip the trailing '\n' after the negative
-            # passage, otherwise it will corrupt the dataframe
-
-            # writing the positive and negative samples into the new sampled data file
-            sf.write(posRowData + '\n')
-            sf.write(negRowData + '\n')
-
-            # increasing the id count
-            startId += 2
-
-    # close the files so the sampled tsv is flushed to disk before re-reading
-    rf.close()
-    sf.close()
-
-    print('Total number of rows in original data: ', i)
-    print('Number of answerable samples in downsampled data: ', int(startId / 2))
-    print('Number of non-answerable samples in downsampled data: ', int(startId / 2))
-    print('Downsampled msmarco triples tsv saved at: {}'.format(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv')))
-
-    # making the train, dev, test split
-    sampledDf = pd.read_csv(os.path.join(wrtDir, 'msmarco_triples_sampled.tsv'), sep='\t', header=None)
-    trainDf, testDf = train_test_split(sampledDf, shuffle=True, random_state=SEED,
-                                       test_size=0.02)
-    trainDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_train.tsv'), sep='\t', index=False, header=False)
-    print('Train file written at: ', os.path.join(wrtDir, 'msmarco_answerability_train.tsv'))
-
-    devDf, testDf = train_test_split(testDf, shuffle=True, random_state=SEED,
-                                     test_size=0.5)
-    devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'), sep='\t', index=False, header=False)
-    print('Dev file written at: ', os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'))
-
-    # write the test split from testDf (the original wrote devDf here, which
-    # would have duplicated the dev set as the test set)
-    testDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
-    print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
 
 def query_correctness_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

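The downsampling arithmetic in the removed function is worth spelling out: with the default `data_frac` of 0.01, every 100th triple is kept, and each kept triple emits two rows (one answerable, one non-answerable). A quick sketch of that arithmetic; the input size below is hypothetical, not the actual size of the MSMARCO triples file.

# Downsampling arithmetic from the removed msmarco_answerability_detection_to_tsv.
# The triple count here is a made-up example; the real triples.train.small
# file is far larger.

data_frac = 0.01
sample_every = int(1 / data_frac)        # keep one triple in every 100

total_triples = 1_000_000                # hypothetical input size
kept = sum(1 for i in range(total_triples) if i % sample_every == 0)

print(kept)        # 10000 triples kept
print(kept * 2)    # 20000 rows written: one positive + one negative per triple

The two `train_test_split` calls then carve the sampled rows into roughly 98% train, 1% dev, and 1% test (`test_size=0.02`, followed by `test_size=0.5` on the held-out 2%).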
0 commit comments