@@ -54,6 +54,7 @@ def main(args = None):
5454 parser .add_argument ("sem" )
5555 parser .add_argument ("--arbi-types" , action = "store_true" , default = False )
5656 parser .add_argument ("--gold_trees" , action = "store_true" , default = True )
57+ parser .add_argument ("--nbest" , nargs = '?' , type = int , default = "1" )
5758 args = parser .parse_args ()
5859
5960 if not os .path .exists (args .templates ):
@@ -71,27 +72,35 @@ def main(args = None):
7172 root = etree .parse (args .ccg , parser )
7273
7374 for sentence in root .findall ('.//sentence' ):
74- sem_node = etree .Element ('semantics' )
75- try :
76- sem_node .set ('status' , 'success' )
77- tree_index = 1
78- if args .gold_trees :
79- tree_index = int (sentence .get ('gold_tree' , '0' )) + 1
80- sem_tree = assign_semantics_to_ccg (
81- sentence , semantic_index , tree_index )
82- sem_node .set ('root' ,
83- sentence .xpath ('./ccg[{0}]/@root' .format (tree_index ))[0 ])
84- filter_attributes (sem_tree )
85- sem_node .extend (sem_tree .xpath ('.//descendant-or-self::span' ))
86- except LogicalExpressionException as e :
87- sem_node .set ('status' , 'failed' )
88- logging .error ('An error occurred: {0}' .format (e ))
89- sentence .append (sem_node )
75+ if args .gold_trees :
76+ tree_indices = [int (sentence .get ('gold_tree' , '0' )) + 1 ]
77+ if args .nbest != 1 :
78+ tree_indices = get_tree_indices (sentence , args .nbest )
79+ for tree_index in tree_indices :
80+ sem_node = etree .Element ('semantics' )
81+ sem_node .set ('ccg_id' ,
82+ sentence .xpath ('./ccg[{0}]/@id' .format (tree_index ))[0 ])
83+ try :
84+ sem_node .set ('status' , 'success' )
85+ sem_tree = assign_semantics_to_ccg (
86+ sentence , semantic_index , tree_index )
87+ sem_node .set ('root' ,
88+ sentence .xpath ('./ccg[{0}]/@root' .format (tree_index ))[0 ])
89+ filter_attributes (sem_tree )
90+ sem_node .extend (sem_tree .xpath ('.//descendant-or-self::span' ))
91+ except LogicalExpressionException as e :
92+ sem_node .set ('status' , 'failed' )
93+ logging .error ('An error occurred: {0}' .format (e ))
94+ sentence .append (sem_node )
9095
9196 root_xml_str = serialize_tree (root )
9297 with codecs .open (args .sem , 'wb' ) as fout :
9398 fout .write (root_xml_str )
9499
def get_tree_indices(sentence, nbest):
    """Return the 1-based positions of the first `nbest` CCG trees.

    The positions are meant to be plugged into XPath predicates such as
    ``./ccg[{0}]``, which is why counting starts at 1.

    Args:
        sentence: Element exposing an lxml-style ``xpath`` method; its
            direct ``ccg`` children are counted.
        nbest: Maximum number of trees to select. ``None`` means "no
            limit" (all trees). ``None`` is what argparse produces when
            ``--nbest`` is passed without a value, because the option is
            declared with ``nargs='?'`` and no ``const``.

    Returns:
        The list ``[1, 2, ..., k]`` with ``k = min(nbest, number of ccg
        children)``. Empty when the sentence has no ``ccg`` child or when
        ``nbest`` is zero or negative.
    """
    # lxml returns XPath numeric results as floats, hence the int().
    num_ccg_trees = int(sentence.xpath('count(./ccg)'))
    if nbest is None:
        # Previously min(None, int) raised TypeError on Python 3.
        limit = num_ccg_trees
    else:
        limit = min(nbest, num_ccg_trees)
    return list(range(1, limit + 1))
103+
# Attribute names to retain; presumably the whitelist consulted by
# filter_attributes (its full body is not visible here -- TODO confirm).
keep_attributes = {'id', 'child', 'sem', 'type'}
96105def filter_attributes (tree ):
97106 if 'coq_type' in tree .attrib and 'child' not in tree .attrib :
0 commit comments