|
20 | 20 | package org.apache.sysds.common; |
21 | 21 |
|
22 | 22 | import org.apache.sysds.lops.*; |
23 | | - |
24 | 23 | import org.apache.sysds.common.Types.OpOp1; |
25 | 24 | import org.apache.sysds.hops.FunctionOp; |
26 | 25 |
|
27 | | -import java.util.EnumSet; |
28 | 26 | import java.util.HashMap; |
29 | 27 | import java.util.Map; |
30 | 28 |
|
@@ -390,8 +388,8 @@ public enum Opcodes { |
390 | 388 |
|
391 | 389 | BINUAGGCHAIN("binuaggchain", InstructionType.BinUaggChain), |
392 | 390 |
|
393 | | - CASTDTM("castdtm", InstructionType.Cast), |
394 | | - CASTDTF("castdtf", InstructionType.Cast), |
| 391 | + CASTDTM("castdtm", InstructionType.Variable, InstructionType.Cast), |
| 392 | + CASTDTF("castdtf", InstructionType.Variable, InstructionType.Cast), |
395 | 393 |
|
396 | 394 | //FED Opcodes |
397 | 395 | FEDINIT("fedinit", InstructionType.Init); |
@@ -427,7 +425,7 @@ public enum Opcodes { |
427 | 425 | private static final Map<String, Opcodes> _lookupMap = new HashMap<>(); |
428 | 426 |
|
429 | 427 | static { |
430 | | - for (Opcodes op : EnumSet.allOf(Opcodes.class)) { |
| 428 | + for (Opcodes op : Opcodes.values()) { |
431 | 429 | if (op._name != null) { |
432 | 430 | _lookupMap.put(op._name.toLowerCase(), op); |
433 | 431 | } |
@@ -456,19 +454,17 @@ public static InstructionType getTypeByOpcode(String opcode, Types.ExecType type |
456 | 454 | if (opcode == null || opcode.trim().isEmpty()) { |
457 | 455 | return null; |
458 | 456 | } |
459 | | - for (Opcodes op : Opcodes.values()) { |
460 | | - if (op.toString().equalsIgnoreCase(opcode.trim())) { |
461 | | - switch (type) { |
462 | | - case SPARK: |
463 | | - return (op.getSpType() != null) ? op.getSpType() : op.getType(); |
464 | | - case FED: |
465 | | - return (op.getFedType() != null) ? op.getFedType() : op.getType(); |
466 | | - default: |
467 | | - return op.getType(); |
468 | | - } |
| 457 | + Opcodes op = _lookupMap.get(opcode.trim().toLowerCase()); |
| 458 | + if( op != null ) { |
| 459 | + switch (type) { |
| 460 | + case SPARK: |
| 461 | + return (op.getSpType() != null) ? op.getSpType() : op.getType(); |
| 462 | + case FED: |
| 463 | + return (op.getFedType() != null) ? op.getFedType() : op.getType(); |
| 464 | + default: |
| 465 | + return op.getType(); |
469 | 466 | } |
470 | 467 | } |
471 | 468 | return null; |
472 | 469 | } |
473 | 470 | } |
474 | | - |
0 commit comments