|
| 1 | +package org.apache.sysds.runtime.einsum; |
| 2 | + |
| 3 | +import java.util.ArrayList; |
| 4 | +import java.util.Arrays; |
| 5 | +import java.util.HashMap; |
| 6 | +import java.util.HashSet; |
| 7 | +import java.util.List; |
| 8 | + |
| 9 | +public class EOpNodeEinsumFuse extends EOpNode { |
| 10 | + public static final int AB_index=0; |
| 11 | + public static final int BA_index=1; |
| 12 | + public static final int B_index=2; |
| 13 | + public static final int XB_index=3; |
| 14 | + public static final int BX_index=4; |
| 15 | + public static final int A_index=5; |
| 16 | + public static final int XA_index=6; |
| 17 | + public static final int AX_index=7; |
| 18 | + public static final int AZ_index=8; |
| 19 | + public enum EinsumRewriteType{ |
| 20 | + // inputops__output 'X' = simplySumDim |
| 21 | + AB_BA_B_XB_BX_A_XA_AX__AB, |
| 22 | + AB_BA_B_XB_BX_A_XA_AX__B, |
| 23 | + AB_BA_B_XB_BX_A_XA_AX__A, |
| 24 | + AB_BA_B_XB_BX_A_XA_AX__, |
| 25 | + |
| 26 | + AB_BA_B_XB_BX_A_XA_AX_AZ__Z |
| 27 | + } |
| 28 | + public enum EinsumRewriteType_v2{ // option 2 without X dims |
| 29 | + AB_BA_B_A__AB, |
| 30 | + AB_BA_B_A__B, |
| 31 | + AB_BA_B_A__A, |
| 32 | + |
| 33 | + AB_BA_B_A_AZ__Z |
| 34 | + } |
| 35 | + public final EinsumRewriteType einsumRewriteType; |
| 36 | + public final List<List<EOpNode>> operands; |
| 37 | + |
| 38 | + private EOpNodeEinsumFuse(Character c1, Character c2, EinsumRewriteType einsumRewriteType, List<EOpNode>... operands) { |
| 39 | + super(c1,c2); |
| 40 | + this.einsumRewriteType = einsumRewriteType; |
| 41 | + this.operands = Arrays.asList(operands); |
| 42 | + } |
| 43 | + |
| 44 | + public static EOpNodeEinsumFuse match(ArrayList<EOpNode> operands, Character outChar1, Character outChar2,/*, Set<Character> simplySummableChars,*/ ArrayList<EOpNode> ret, HashMap<Character, Integer> charToOccurences){ |
| 45 | + //precompute |
| 46 | + HashSet<String> matricesChars = new HashSet<>(); |
| 47 | + HashMap<String, ArrayList<EOpNode>> charsToMatrices = new HashMap<>(); |
| 48 | + HashMap<Character, Integer> charsToNumberOfOperands = new HashMap<>(); |
| 49 | + |
| 50 | + for (EOpNode operand1 : operands) { |
| 51 | + String k; |
| 52 | +//todo remove and use input map charToOccurences |
| 53 | + if (charsToNumberOfOperands.containsKey(operand1.c1)) { |
| 54 | + charsToNumberOfOperands.put(operand1.c1, charsToNumberOfOperands.get(operand1.c1) + 1); |
| 55 | + } else { |
| 56 | + charsToNumberOfOperands.put(operand1.c1, 1); |
| 57 | + } |
| 58 | + |
| 59 | + if (operand1.c2 != null) { |
| 60 | + k = operand1.c1.toString() + operand1.c2; |
| 61 | + matricesChars.add(k); |
| 62 | + if (charsToNumberOfOperands.containsKey(operand1.c2)) { |
| 63 | + charsToNumberOfOperands.put(operand1.c2, charsToNumberOfOperands.get(operand1.c2) + 1); |
| 64 | + } else { |
| 65 | + charsToNumberOfOperands.put(operand1.c2, 1); |
| 66 | + } |
| 67 | + } else { |
| 68 | + k = operand1.c1.toString(); |
| 69 | + } |
| 70 | + |
| 71 | + if (charsToMatrices.containsKey(k)) { |
| 72 | + charsToMatrices.get(k).add(operand1); |
| 73 | + } else { |
| 74 | + ArrayList<EOpNode> matrices = new ArrayList<>(); |
| 75 | + matrices.add(operand1); |
| 76 | + charsToMatrices.put(k, matrices); |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + ArrayList<EOpNode> AXs = new ArrayList<>(); |
| 81 | + ArrayList<EOpNode> XAs = new ArrayList<>(); |
| 82 | + ArrayList<EOpNode> BXs = new ArrayList<>(); |
| 83 | + ArrayList<EOpNode> XBs = new ArrayList<>(); |
| 84 | + ArrayList<EOpNode> AZs = new ArrayList<>(); |
| 85 | + boolean pass = false; |
| 86 | + |
| 87 | + String AB = null; |
| 88 | + String BA = null; |
| 89 | + boolean doSumA=false; |
| 90 | + boolean doSumB=false; |
| 91 | + |
| 92 | + for (String ABcandidate : matricesChars) { |
| 93 | + char a = ABcandidate.charAt(0); |
| 94 | + char b = ABcandidate.charAt(1); |
| 95 | + BA = "" + b + a; |
| 96 | + |
| 97 | + AXs = new ArrayList<>(); |
| 98 | + XAs = new ArrayList<>(); |
| 99 | + BXs = new ArrayList<>(); |
| 100 | + XBs = new ArrayList<>(); |
| 101 | + AZs = new ArrayList<>(); |
| 102 | + |
| 103 | + pass=true; |
| 104 | + |
| 105 | + |
| 106 | + for (String chars : charsToMatrices.keySet()) { |
| 107 | + if (chars.equals(ABcandidate) || chars.equals(BA)) { |
| 108 | +// ABsCounter++; |
| 109 | + continue; |
| 110 | + } |
| 111 | + |
| 112 | + if(chars.length()==1){ |
| 113 | + if(chars.charAt(0)==a){ |
| 114 | +// AsCounter++; |
| 115 | + }else if(chars.charAt(0)==b){ |
| 116 | +// BsCounter++; |
| 117 | + } |
| 118 | + continue; |
| 119 | + //always ok |
| 120 | + }else{ |
| 121 | + if(a==chars.charAt(1) && b==chars.charAt(0)){ |
| 122 | +// ABsCounter++; |
| 123 | + //BA |
| 124 | + continue; |
| 125 | + } |
| 126 | + if(chars.charAt(0)==a){ |
| 127 | + if(charsToNumberOfOperands.get(chars.charAt(1))==1){ |
| 128 | + if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { |
| 129 | + AXs.addAll(charsToMatrices.get(chars)); |
| 130 | +// AsCounter++; |
| 131 | + continue; |
| 132 | + }else{ |
| 133 | + if(AZs.size()==0){ |
| 134 | + AZs = charsToMatrices.get(chars); |
| 135 | + continue; |
| 136 | + } |
| 137 | + pass = false; |
| 138 | + break; |
| 139 | + } |
| 140 | + }else{ |
| 141 | + //dont allow for now, in theory AZ,Z or AZ,AZ would also work, but for now do them separately |
| 142 | + pass = false; |
| 143 | + break; |
| 144 | + } |
| 145 | + } |
| 146 | + else if(chars.charAt(0)==b){ |
| 147 | + if(charsToNumberOfOperands.get(chars.charAt(1))==1){ |
| 148 | + if(chars.charAt(1)!= outChar1 && chars.charAt(1) != outChar2) { |
| 149 | + BXs.addAll(charsToMatrices.get(chars)); |
| 150 | +// BsCounter++; |
| 151 | + continue; |
| 152 | + }else{ |
| 153 | + pass = false; // no BZ, maybe experiment later |
| 154 | + break; |
| 155 | + } |
| 156 | + }else{ |
| 157 | + pass = false; |
| 158 | + break; |
| 159 | + } |
| 160 | + } |
| 161 | + else if(chars.charAt(1)==a){ |
| 162 | + if(charsToNumberOfOperands.get(chars.charAt(0))==1){ |
| 163 | + if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { |
| 164 | + XAs.addAll(charsToMatrices.get(chars)); |
| 165 | +// AsCounter++; |
| 166 | + continue; |
| 167 | + }else{ |
| 168 | + pass = false; |
| 169 | + break; |
| 170 | + } |
| 171 | + }else{ |
| 172 | + pass = false; |
| 173 | + break; |
| 174 | + } |
| 175 | + } |
| 176 | + else if(chars.charAt(1)==b){ |
| 177 | + if(charsToNumberOfOperands.get(chars.charAt(0))==1){ |
| 178 | + if(chars.charAt(0)!= outChar1 && chars.charAt(0) != outChar2) { |
| 179 | + XBs.addAll(charsToMatrices.get(chars)); |
| 180 | +// BsCounter++; |
| 181 | + continue; |
| 182 | + }else{ |
| 183 | + pass = false; |
| 184 | + break; |
| 185 | + } |
| 186 | + }else{ |
| 187 | + pass = false; |
| 188 | + break; |
| 189 | + } |
| 190 | + } |
| 191 | + } |
| 192 | + } |
| 193 | + if(pass){ |
| 194 | + AB = ABcandidate; |
| 195 | + String A = ""+a; |
| 196 | + String B = ""+b; |
| 197 | + int ABsCounter = charsToMatrices.get(ABcandidate).size()+(charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA).size() : 0); |
| 198 | + int AsCounter = (charsToMatrices.containsKey(A) ? charsToMatrices.get(A).size() : 0) +AXs.size()+XAs.size(); |
| 199 | + int BsCounter = (charsToMatrices.containsKey(B) ? charsToMatrices.get(B).size() : 0)+BXs.size()+XBs.size(); |
| 200 | + if(AsCounter==0 && BsCounter==0 && ABsCounter<2){ |
| 201 | + pass=false; |
| 202 | + continue; |
| 203 | + } |
| 204 | + int usedAsCount = AsCounter+ABsCounter; |
| 205 | + int usedBsCount = BsCounter+ABsCounter; |
| 206 | + doSumA = charToOccurences.get(a)==usedAsCount && (outChar1 == null || a!=outChar1) && (outChar2 == null || a!=outChar2); |
| 207 | + doSumB = charToOccurences.get(b)==usedBsCount && (outChar1 == null || b!=outChar1) && (outChar2 == null || b!=outChar2); |
| 208 | + if(AZs.size()!=0) { // invalidate AZ fusion |
| 209 | + if (outChar1 != null) { |
| 210 | + if (a == outChar1 || b == outChar1) { |
| 211 | + pass=false; |
| 212 | + continue; |
| 213 | + } |
| 214 | + } |
| 215 | + if (outChar2 != null) { |
| 216 | + if (a == outChar2 || b == outChar2) { |
| 217 | + pass=false; |
| 218 | + continue; |
| 219 | + } |
| 220 | + } |
| 221 | + if(!doSumA || !doSumB){ |
| 222 | + pass=false; |
| 223 | + continue; |
| 224 | + } |
| 225 | + } |
| 226 | + break; |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + if(!pass){ |
| 231 | + return null; |
| 232 | + } |
| 233 | + String B = AB.substring(1,2); |
| 234 | + String A = AB.substring(0,1); |
| 235 | + Character c1 = null; |
| 236 | + Character c2 = null; |
| 237 | + EinsumRewriteType t; |
| 238 | + |
| 239 | + if(AZs.size()!=0){ |
| 240 | + c1=AZs.get(0).c2; |
| 241 | + t=EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX_AZ__Z; |
| 242 | + } |
| 243 | + else if(doSumA){ |
| 244 | + if(doSumB) { |
| 245 | + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__; |
| 246 | + } |
| 247 | + else { |
| 248 | + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__B; |
| 249 | + c1 = AB.charAt(1); |
| 250 | + } |
| 251 | + } |
| 252 | + else if(doSumB){ |
| 253 | + t= EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__A; |
| 254 | + c1= AB.charAt(0); |
| 255 | + } |
| 256 | + else { |
| 257 | + t = EinsumRewriteType.AB_BA_B_XB_BX_A_XA_AX__AB; |
| 258 | + c1 = AB.charAt(0); |
| 259 | + c2 = AB.charAt(1); |
| 260 | + } |
| 261 | + if(c1 != null){ |
| 262 | + charToOccurences.put(c1, charToOccurences.get(c1)+1); |
| 263 | + } |
| 264 | + if(c2 != null){ |
| 265 | + charToOccurences.put(c2, charToOccurences.get(c2)+1); |
| 266 | + } |
| 267 | + HashSet<EOpNode> usedOperands = new HashSet<>(); |
| 268 | + |
| 269 | + ArrayList<EOpNode> ABs=charsToMatrices.containsKey(AB) ? charsToMatrices.get(AB) : new ArrayList<>(); |
| 270 | + ArrayList<EOpNode> BAs=charsToMatrices.containsKey(BA) ? charsToMatrices.get(BA) : new ArrayList<>(); |
| 271 | + ArrayList<EOpNode> Bs=charsToMatrices.containsKey(B) ? charsToMatrices.get(B) : new ArrayList<>(); |
| 272 | + ArrayList<EOpNode> As=charsToMatrices.containsKey(A) ? charsToMatrices.get(A) : new ArrayList<>(); |
| 273 | + |
| 274 | + usedOperands.addAll(ABs); |
| 275 | + usedOperands.addAll(BAs); |
| 276 | + usedOperands.addAll(Bs); |
| 277 | + usedOperands.addAll(As); |
| 278 | + usedOperands.addAll(XBs); |
| 279 | + usedOperands.addAll(BXs); |
| 280 | + usedOperands.addAll(XAs); |
| 281 | + usedOperands.addAll(AXs); |
| 282 | + usedOperands.addAll(AZs); |
| 283 | + |
| 284 | + for(EOpNode n : operands){ |
| 285 | + if(!usedOperands.contains(n)){ |
| 286 | + ret.add(n); |
| 287 | + }else{ |
| 288 | + if(charToOccurences != null){ |
| 289 | + charToOccurences.put(n.c1, charToOccurences.get(n.c1)-1); |
| 290 | + if(charToOccurences.get(n.c2)!= null) |
| 291 | + charToOccurences.put(n.c2, charToOccurences.get(n.c2)-1); |
| 292 | + } |
| 293 | + } |
| 294 | + } |
| 295 | + |
| 296 | + var e = new EOpNodeEinsumFuse(c1, c2, t, |
| 297 | + ABs, |
| 298 | + BAs, |
| 299 | + Bs, |
| 300 | + XBs, |
| 301 | + BXs, |
| 302 | + As, |
| 303 | + XAs, |
| 304 | + AXs, |
| 305 | + AZs |
| 306 | + ); |
| 307 | + ret.add(e); |
| 308 | + return e; |
| 309 | + } |
| 310 | +} |
| 311 | + |
0 commit comments