@@ -46,9 +46,10 @@ public class OrNode extends Node<OrNode, OrActivation> {
4646
4747 private static final Logger log = LoggerFactory .getLogger (OrNode .class );
4848
49- public TreeSet <OrEntry > andParents = new TreeSet <>();
49+ TreeSet <OrEntry > andParents = new TreeSet <>();
50+
51+ private Neuron outputNeuron = null ;
5052
51- public Neuron neuron = null ;
5253
5354 public OrNode () {}
5455
@@ -66,45 +67,22 @@ public AndNode.RefValue extend(int threadId, Document doc, AndNode.Refinement re
6667
6768 public void addInputActivation (OrEntry oe , NodeActivation inputAct ) {
6869 Document doc = inputAct .getDocument ();
70+ INeuron n = outputNeuron .get (doc );
6971
70- SortedMap <Integer , Position > slots = new TreeMap <>();
71-
72- INeuron n = neuron .get (doc );
73- for (int i = 0 ; i < oe .synapseIds .length ; i ++) {
74- int synapseId = oe .synapseIds [i ];
75-
76- Synapse s = neuron .getSynapseById (synapseId );
77- for (Map .Entry <Integer , Relation > me : s .getRelations ().entrySet ()) {
78- Relation rel = me .getValue ();
79- if (me .getKey () == Synapse .OUTPUT ) {
80- Activation iAct = inputAct .getInputActivation (i );
81- rel .mapSlots (slots , iAct );
82- }
83- }
84- }
85-
86- for (Integer slot : n .slotRequired ) {
87- if (!slots .containsKey (slot )) {
88- if (!n .slotHasInputs .contains (slot )) {
89- slots .put (slot , new Position (doc ));
90- } else {
91- return ;
92- }
93- }
94- }
72+ SortedMap <Integer , Position > slots = getSlots (oe , inputAct );
73+ if (n .checkRequiredSlots (doc , slots )) return ;
9574
9675 Activation act = lookupActivation (doc , slots , oe , inputAct );
9776
9877 if (act == null ) {
9978 OrActivation orAct = new OrActivation (doc , this );
100- act = new Activation (doc , neuron .get (doc ));
79+
80+ act = new Activation (doc , outputNeuron .get (doc ), slots );
81+
10182 act .setInputNodeActivation (orAct );
10283 orAct .setOutputAct (act );
10384
104- act .setSlots (slots );
105-
10685 processActivation (orAct );
107- neuron .get (act .getDocument ()).register (act );
10886 } else {
10987 propagate (act .getInputNodeActivation ());
11088 }
@@ -114,8 +92,26 @@ public void addInputActivation(OrEntry oe, NodeActivation inputAct) {
11492 }
11593
11694
95+ public SortedMap <Integer , Position > getSlots (OrEntry oe , NodeActivation inputAct ) {
96+ SortedMap <Integer , Position > slots = new TreeMap <>();
97+ for (int i = 0 ; i < oe .synapseIds .length ; i ++) {
98+ int synapseId = oe .synapseIds [i ];
99+
100+ Synapse s = outputNeuron .getSynapseById (synapseId );
101+ for (Map .Entry <Integer , Relation > me : s .getRelations ().entrySet ()) {
102+ Relation rel = me .getValue ();
103+ if (me .getKey () == Synapse .OUTPUT ) {
104+ Activation iAct = inputAct .getInputActivation (i );
105+ rel .mapSlots (slots , iAct );
106+ }
107+ }
108+ }
109+ return slots ;
110+ }
111+
112+
117113 private Activation lookupActivation (Document doc , SortedMap <Integer , Position > slots , OrEntry oe , NodeActivation inputAct ) {
118- x : for (Activation act : neuron .get (doc )
114+ x : for (Activation act : outputNeuron .get (doc )
119115 .getActivations (doc , slots )
120116 ) {
121117
@@ -168,7 +164,7 @@ public void apply(OrActivation act) {
168164 }
169165
170166
171- public static void processCandidate (Node <?, ? extends NodeActivation <?>> parentNode , NodeActivation inputAct , boolean train ) {
167+ static void processCandidate (Node <?, ? extends NodeActivation <?>> parentNode , NodeActivation inputAct , boolean train ) {
172168 Document doc = inputAct .getDocument ();
173169 try {
174170 parentNode .lock .acquireReadLock ();
@@ -195,7 +191,7 @@ public void reprocessInputs(Document doc) {
195191 }
196192
197193
198- public void addInput (int [] synapseIds , int threadId , Node in , boolean andMode ) {
194+ void addInput (int [] synapseIds , int threadId , Node in , boolean andMode ) {
199195 in .changeNumberOfNeuronRefs (threadId , provider .model .visitedCounter .addAndGet (1 ), 1 );
200196
201197 OrEntry oe = new OrEntry (synapseIds , in .getProvider (), provider );
@@ -212,7 +208,7 @@ public void addInput(int[] synapseIds, int threadId, Node in, boolean andMode) {
212208
213209
214210 void remove (int threadId ) {
215- neuron .get ().remove ();
211+ outputNeuron .get ().remove ();
216212
217213 super .remove ();
218214
@@ -225,7 +221,7 @@ void remove(int threadId) {
225221 }
226222
227223
228- public void removeParents (int threadId ) {
224+ void removeParents (int threadId ) {
229225 for (OrEntry oe : andParents ) {
230226 Node pn = oe .parent .get ();
231227 pn .changeNumberOfNeuronRefs (threadId , provider .model .visitedCounter .addAndGet (1 ), -1 );
@@ -237,7 +233,7 @@ public void removeParents(int threadId) {
237233
238234
239235 @ Override
240- public void changeNumberOfNeuronRefs (int threadId , long v , int d ) {
236+ protected void changeNumberOfNeuronRefs (int threadId , long v , int d ) {
241237 throw new UnsupportedOperationException ();
242238 }
243239
@@ -272,7 +268,7 @@ public void write(DataOutput out) throws IOException {
272268 out .writeChar ('O' );
273269 super .write (out );
274270
275- out .writeInt (neuron .id );
271+ out .writeInt (outputNeuron .id );
276272
277273 out .writeInt (andParents .size ());
278274 for (OrEntry oe : andParents ) {
@@ -285,7 +281,7 @@ public void write(DataOutput out) throws IOException {
285281 public void readFields (DataInput in , Model m ) throws IOException {
286282 super .readFields (in , m );
287283
288- neuron = m .lookupNeuron (in .readInt ());
284+ outputNeuron = m .lookupNeuron (in .readInt ());
289285
290286 int s = in .readInt ();
291287 for (int i = 0 ; i < s ; i ++) {
@@ -295,16 +291,25 @@ public void readFields(DataInput in, Model m) throws IOException {
295291
296292
297293 public String getNeuronLabel () {
298- String l = neuron .getLabel ();
294+ String l = outputNeuron .getLabel ();
299295 return l != null ? l : "" ;
300296 }
301297
302298
303- public static class OrEntry implements Comparable <OrEntry >, Writable {
304- public int [] synapseIds ;
305- public TreeMap <Integer , Integer > revSynapseIds = new TreeMap <>();
306- public Provider <? extends Node > parent ;
307- public Provider <OrNode > child ;
299+ public void setOutputNeuron (Neuron n ) {
300+ outputNeuron = n ;
301+ }
302+
303+ public Neuron getOutputNeuron () {
304+ return outputNeuron ;
305+ }
306+
307+
308+ static class OrEntry implements Comparable <OrEntry >, Writable {
309+ int [] synapseIds ;
310+ TreeMap <Integer , Integer > revSynapseIds = new TreeMap <>();
311+ Provider <? extends Node > parent ;
312+ Provider <OrNode > child ;
308313
309314 private OrEntry () {}
310315
@@ -318,6 +323,7 @@ public OrEntry(int[] synapseIds, Provider<? extends Node> parent, Provider<OrNod
318323 this .child = child ;
319324 }
320325
326+
321327 @ Override
322328 public void write (DataOutput out ) throws IOException {
323329 out .writeInt (synapseIds .length );
@@ -413,5 +419,14 @@ public Link(OrEntry oe, NodeActivation<?> input, OrActivation output) {
413419 this .input = input ;
414420 this .output = output ;
415421 }
422+
423+
424+ public int size () {
425+ return oe .synapseIds .length ;
426+ }
427+
428+ public int get (int i ) {
429+ return oe .synapseIds [i ];
430+ }
416431 }
417432}
0 commit comments