Skip to content

Commit 9b2adb9

Browse files
ssdf93ljshou
authored andcommitted
Fix concat in interaction layer. (#57)
* Add Support for MatchPyramid model * Modified interaction layer * change 2d to 2D * Flatten conflicts solved. * fix conv2d and pooling2d * add MatchPyramid score report * add MatchPyramid score record * Fix concat in interaction layer. * merge * Delete Conv2d.py * Delete Pooling2d.py * Delete test.json * Update .gitignore
1 parent a1041f4 commit 9b2adb9

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
*.vs*
66
dataset/GloVe/
77
dataset/20_newsgroups/
8-
models/
8+
models/

block_zoo/attentions/Interaction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def inference(self):
5858
def verify(self):
5959
super(InteractionConf, self).verify()
6060
assert hasattr(self, 'matching_type'), "Please define matching_type attribute of BiGRUConf in default() or the configuration file"
61-
assert self.matching_type in ['general', 'dot', 'mul', 'plus', 'minus', 'add'], "Invalid `matching_type`{self.matching_type} received. Must be in `mul`, `general`, `plus`, `minus`, `dot` and `concat`."
61+
assert self.matching_type in ['general', 'dot', 'mul', 'plus', 'minus', 'add', 'concat'], "Invalid `matching_type`{self.matching_type} received. Must be in `mul`, `general`, `plus`, `minus`, `dot` and `concat`."
6262

6363

6464
class Interaction(BaseLayer):
@@ -120,7 +120,7 @@ def func(x, y):
120120
return x - y
121121
elif self.matching_type == 'concat':
122122
def func(x, y):
123-
return torch.concat([x, y], axis=-1)
123+
return torch.cat([x, y], dim=-1)
124124
else:
125125
raise ValueError(f"Invalid matching type."
126126
f"{self.matching_type} received."

model_zoo/nlp_tasks/question_answer_matching/conf_question_answer_matching_pyramid.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"valid_times_per_epoch": 5,
5858
"fixed_lengths":{
5959
"question": 30,
60-
"passage": 120
60+
"passage": 200
6161
}
6262
},
6363
"architecture":[
@@ -92,7 +92,6 @@
9292
"layer": "Interaction",
9393
"conf": {
9494
"dropout": 0.2,
95-
"hidden_dim": 300,
9695
"matching_type": "general"
9796
},
9897
"inputs": ["question_dropout", "passage_dropout"]

0 commit comments

Comments
 (0)