Skip to content

Commit f6b93e2

Browse files
small change to save lines
1 parent 3c5f86a commit f6b93e2

File tree

1 file changed

+21
-38
lines changed

1 file changed

+21
-38
lines changed

src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -492,125 +492,107 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList<MatrixBlock> input
492492

493493
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
494494

495+
MatrixBlock res;
495496
switch (bin.operand){
496497
case AB_AB -> {
497-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
498-
return res;
498+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
499499
}
500500
case A_A -> {
501501
EnsureMatrixBlockColumnVector(left);
502502
EnsureMatrixBlockColumnVector(right);
503-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
504-
return res;
503+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
505504
}
506505
case a_a -> {
507506
EnsureMatrixBlockColumnVector(left);
508507
EnsureMatrixBlockColumnVector(right);
509-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
508+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
510509
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
511510
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
512-
return res;
513511
}
514512
////////////
515513
case Ba_Ba -> {
516-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
514+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
517515
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
518516
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
519-
return res;
520517
}
521518
case aB_aB -> {
522-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
519+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
523520
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
524521
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
525522
EnsureMatrixBlockColumnVector(res);
526-
return res;
527523
}
528524
case ab_ab -> {
529-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
525+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
530526
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
531527
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
532-
return res;
533528
}
534529
case ab_ba -> {
535530
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
536531
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
537-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
532+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
538533
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
539534
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
540-
return res;
541535
}
542536
case Ba_aB -> {
543537
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
544538
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
545-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
539+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
546540
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
547541
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
548-
return res;
549542
}
550543

551544
/////////
552545
case AB_BA -> {
553546
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
554547
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
555-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
556-
return res;
548+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, right},new ScalarObject[]{}, new MatrixBlock());
557549
}
558550
case Ba_aC -> {
559-
var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
560-
return res;
551+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
561552
}
562553
case aB_Ca -> {
563-
var res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads);
564-
return res;
554+
res = LibMatrixMult.matrixMult(right,left, new MatrixBlock(), _numThreads);
565555
}
566556
case Ba_Ca -> {
567557
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
568558
right = right.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
569-
var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
570-
return res;
559+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
571560
}
572561
case aB_aC -> {
573562
ReorgOperator transpose = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
574563
left = left.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
575-
var res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
576-
return res;
564+
res = LibMatrixMult.matrixMult(left,right, new MatrixBlock(), _numThreads);
577565
}
578566
case A_scalar, AB_scalar -> {
579-
var res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
580-
return res;
567+
res = MatrixBlock.naryOperations(new SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
581568
}
582569
case BA_A -> {
583570
EnsureMatrixBlockRowVector(right);
584-
var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
585-
return res;
571+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
586572
}
587573
case Ba_a -> {
588574
EnsureMatrixBlockRowVector(right);
589-
var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
575+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
590576
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
591577
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
592-
return res;
593578
}
594579

595580
case AB_A -> {
596581
EnsureMatrixBlockColumnVector(right);
597-
var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
598-
return res;
582+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
599583
}
600584
case aB_a -> {
601585
EnsureMatrixBlockColumnVector(right);
602-
var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
586+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
603587
AggregateUnaryOperator aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
604588
res = (MatrixBlock) res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
605589
EnsureMatrixBlockColumnVector(res);
606-
return res;
607590
}
608591

609592
case A_B -> {
610593
EnsureMatrixBlockColumnVector(left);
611594
EnsureMatrixBlockRowVector(right);
612-
var res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
613-
return res;
595+
res = left.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), right);
614596
}
615597
case scalar_scalar -> {
616598
return new MatrixBlock(left.get(0,0)*right.get(0,0));
@@ -620,6 +602,7 @@ private MatrixBlock ComputeEOpNode(EOpNode eOpNode, ArrayList<MatrixBlock> input
620602
}
621603

622604
}
605+
return res;
623606
}
624607

625608
private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, ArrayList<MatrixBlock> inputs){

0 commit comments

Comments
 (0)