 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
 
 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }
 
-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }
 
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
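As an aside, the disabled helper above relies on the shared_ptr-with-custom-deleter idiom to pair std::aligned_alloc with std::free. A minimal standalone sketch of just that idiom (makeAlignedBuffer is a hypothetical name, not part of this patch):

#include <cstdlib>
#include <memory>

// Sketch: a 64-byte-aligned buffer whose lifetime is owned by shared_ptr.
// Note std::aligned_alloc formally requires size to be a multiple of the
// alignment, so callers should round size up first (see divideAndCeil).
std::shared_ptr<void> makeAlignedBuffer(size_t size) {
  void *raw = std::aligned_alloc(64, size);
  // The custom deleter guarantees std::free, matching the allocator.
  return std::shared_ptr<void>{raw, [](void *p) { std::free(p); }};
}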
@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
     totalSize += divideAndCeil(buffersSize[i], 64) * 64;
   }
   llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-  auto base = createConstCacheProxy(totalSize);
+  // auto base = createConstCacheProxy(totalSize);
   std::vector<uint64_t> globalIds(buffersSize.size());
   size_t offset = 0;
   for (size_t i = 0; i < buffersSize.size(); i++) {
     llvm::dbgs() << "Alloc offset: " << offset << '\n';
-    regCachedTensor(cachedTensorGlobalId, base, offset);
+    // regCachedTensor(cachedTensorGlobalId, base, offset);
     globalIds[i] = cachedTensorGlobalId;
     ++cachedTensorGlobalId;
     offset += divideAndCeil(buffersSize[i], 64) * 64;
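To make the offset arithmetic above concrete: each cached tensor is placed at a 64-byte-aligned offset inside one slab, with both the per-buffer stride and the total size rounded up via divideAndCeil. An illustrative standalone helper (computeAlignedOffsets is not part of the patch):

#include <cstddef>
#include <vector>

size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

// Illustrative only: the offsets the manager assigns when packing cached
// tensors into one 64-byte-aligned slab. E.g. buffer sizes {100, 64, 1}
// yield offsets {0, 128, 192} and a total allocation of 256 bytes.
std::vector<size_t> computeAlignedOffsets(const std::vector<size_t> &sizes) {
  std::vector<size_t> offsets;
  size_t offset = 0;
  for (size_t s : sizes) {
    offsets.push_back(offset);
    offset += divideAndCeil(s, 64) * 64; // round each buffer up to 64 B
  }
  return offsets;
}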
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topology: a constant weight may fan out to several
+  // constant outputs, and an op on the way may take several constant inputs.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
     if (constArgsIndexes.count(id) == 1) {
+      // Whether the ops on this weight's subgraph form a simple chain:
+      // single use, single operand, single result at every step.
+      bool simpleTopo = true;
       auto arg = block.getArgument(id);
       if (!isa<TensorType>(arg.getType())) {
         continue;
@@ -444,54 +444,72 @@ void CST::runOnOperation() {
       v = dyn_cast<Value>(arg);
       inputValues.push_back(v);
       SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
       // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
-      while (!v.getUsers().empty()) {
-        // v.getUsers().size() should be 1
-        Operation *user = *(v.getUsers().begin());
-        // If user is not const or user has multiple operand, we reach the end
-        if (!isInConstantSubgraph(user) || !singleOperand(user)) {
-          outputTypes.push_back(v.getType());
-          outputValues.push_back(v);
-          break;
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // If the children ops of v are not all constant, we end at v.
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+        if (!v.hasOneUse()) {
+          simpleTopo = false;
+        }
+        // The children ops of v are all constant, so we push their results
+        // onto the queue.
+        for (Operation *child : v.getUsers()) {
+          if (!singleOperand(child) || child->getResults().size() > 1) {
+            simpleTopo = false;
+          }
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
         }
-        // user should has only 1 output value
-        OpResult result = *(user->result_begin());
-        v = dyn_cast<Value>(result);
-        valuesOnTheWay.push_back(v);
       }
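The loop above replaces the old single-chain walk with a breadth-first traversal, so one weight can reach several constant outputs. A minimal generic sketch of the same worklist pattern (hypothetical names; isInside stands in for isInConstantSubgraph, successors for "results of the ops that use this value"):

#include <deque>
#include <functional>
#include <vector>

// Sketch: BFS from a root, collecting the "frontier" values whose users
// are no longer all inside the region. Mirrors the traversal above; the
// real pass additionally deduplicates frontier entries via std::find.
template <typename Node>
std::vector<Node> collectFrontier(
    Node root, const std::function<std::vector<Node>(Node)> &successors,
    const std::function<bool(Node)> &isInside) {
  std::vector<Node> frontier;
  std::deque<Node> worklist{root};
  while (!worklist.empty()) {
    Node n = worklist.front();
    worklist.pop_front();
    bool escapes = false; // does n flow to a node outside the region?
    for (const Node &s : successors(n))
      escapes = escapes || !isInside(s);
    if (escapes) {
      frontier.push_back(n); // n becomes an output of the folded region
      continue;
    }
    for (const Node &s : successors(n))
      worklist.push_back(s); // keep walking while everything stays inside
  }
  return frontier;
}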
 
       // If the data size of outputValue is much greater than the size of
       // inputValue, do not fold it. Compare data size changes during the
       // traversal to find the last op that satisfies this condition.
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
-        size_t lastIdx = 0;
-        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size =
-              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
-          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
-            lastIdx = i;
+      if (simpleTopo) {
+        int64_t initSize =
+            getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+        if (!isa<TensorType>(outputTypes.back()) ||
+            initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+                getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+          size_t lastIdx = 0;
+          for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+            int64_t size = getTensorSize(
+                dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+            if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+              lastIdx = i;
+            }
+          }
+          if (lastIdx == 0) { // no suitable value found
+            inputTypes.pop_back();
+            outputTypes.pop_back();
+            inputValues.pop_back();
+            outputValues.pop_back();
+            constArgsIndexes.erase(id);
+          } else {
+            outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+            outputValues.back() = valuesOnTheWay[lastIdx];
           }
-        }
-        if (lastIdx == 0) { // no suitable value found
-          inputTypes.pop_back();
-          outputTypes.pop_back();
-          inputValues.pop_back();
-          outputValues.pop_back();
-          constArgsIndexes.erase(id);
-        } else {
-          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
-          outputValues.back() = valuesOnTheWay[lastIdx];
         }
       }
     }
   }
-  if (inputTypes.size() != outputTypes.size()) {
-    return;
-  }
 
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
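As a worked example of the trimming rule above, with DATA_SIZE_EXPANDING_THRESHOLD = 8 (declared in the first hunk header), the chain is cut at the last value that has not blown up past eight times the input size. A self-contained sketch with hypothetical sizes:

#include <cstdint>
#include <vector>

static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;

// Sketch of the rule: index of the last value whose size stays below
// initSize * threshold, or 0 if none qualifies (then the arg is not folded).
size_t lastFoldableIndex(const std::vector<int64_t> &sizes) {
  int64_t initSize = sizes[0];
  size_t lastIdx = 0;
  for (size_t i = 1; i < sizes.size(); ++i)
    if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > sizes[i])
      lastIdx = i;
  return lastIdx; // e.g. {1024, 2048, 65536} -> 1, since 65536 >= 8192
}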
@@ -548,30 +566,34 @@ void CST::runOnOperation() {
   moduleOp.push_back(foldFunc);
   symbolTable.insert(foldFunc);
 
+  // The indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
+  // The indexes of the folded args.
   SmallVector<int32_t> foldIds;
+  // The indexes of args to the computing func.
   SmallVector<int32_t> computeArgs;
 
   // modify the BlockArguments of block
   size_t oriNumArgs = block.getNumArguments();
-  size_t argIdx = 0;
+  // Add the folded args to the end of the BlockArguments list.
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    auto loc = block.getArgument(id).getLoc();
+    BlockArgument foldArg =
+        block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
+    outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == &block;
+    });
+    foldIds.push_back(id + oriNumArgs);
+  }
+  // Erase the operations on the constant args.
   for (size_t id = 0; id < oriNumArgs; ++id) {
     if (constArgsIndexes.count(id) == 1) {
       foldArgs.push_back(id);
-      foldIds.push_back(argIdx + oriNumArgs);
-      computeArgs.push_back(argIdx + oriNumArgs);
-      auto loc = block.getArgument(id).getLoc();
-      BlockArgument foldArg =
-          block.insertArgument(id, outputTypes[argIdx], loc);
-      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
-        Operation *op = val.getOwner();
-        return op->getBlock() == &block;
-      });
-
       std::deque<Value> dq;
       SmallVector<Operation *> opsToErase;
       std::unordered_set<Operation *> opsToEraseSet;
-      dq.push_back(block.getArgument(id + 1));
+      dq.push_back(block.getArgument(id));
       while (!dq.empty()) {
         Value v = dq.front();
         dq.pop_front();
@@ -586,16 +608,26 @@ void CST::runOnOperation() {
           opsToEraseSet.insert(op);
         }
       }
-
       for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
         (*it)->erase();
       }
-      block.eraseArgument(id + 1);
-      ++argIdx;
     } else {
       computeArgs.push_back(id);
     }
   }
+  // Erase the constant args from the BlockArguments list.
+  llvm::BitVector argsToErase;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      argsToErase.push_back(true);
+    } else {
+      argsToErase.push_back(false);
+    }
+  }
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    argsToErase.push_back(false);
+  }
+  block.eraseArguments(argsToErase);
 
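Block::eraseArguments expects one bit per current block argument, so the mask above must also cover the freshly appended folded args (all false) for them to survive. A minimal sketch of that mask-building pattern, assuming constArgsIndexes is a std::unordered_set<size_t> (its exact type is not shown in this diff):

#include <cstddef>
#include <unordered_set>
#include "llvm/ADT/BitVector.h"

// Sketch: bits [0, oriNumArgs) mark the original constant args for
// deletion; the trailing numFolded bits stay false to keep the folded args.
llvm::BitVector buildEraseMask(size_t oriNumArgs, size_t numFolded,
                               const std::unordered_set<size_t> &constIdx) {
  llvm::BitVector mask(oriNumArgs + numFolded, false);
  for (size_t id = 0; id < oriNumArgs; ++id)
    if (constIdx.count(id))
      mask.set(id);
  return mask;
}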
   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
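To make the resulting index arrays concrete, here is a hypothetical layout (the values are illustrative; the length prefix is only shown for __compute_args in the hunk below):

#include <cstdint>
#include <vector>

// Hypothetical sizes: oriNumArgs = 4, constant args {1, 2}, and two
// folded values appended as block args 4 and 5.
int main() {
  std::vector<int32_t> foldArgs = {1, 2};    // original constant arg indexes
  std::vector<int32_t> foldIds = {4, 5};     // indexes of the folded args
  std::vector<int32_t> computeArgs = {0, 3}; // non-constant arg indexes
  foldArgs.insert(foldArgs.end(), foldIds.begin(), foldIds.end());
  computeArgs.insert(computeArgs.end(), foldIds.begin(), foldIds.end());
  // Length prefix, as done for __compute_args in the hunk below.
  computeArgs.insert(computeArgs.begin(),
                     static_cast<int32_t>(computeArgs.size()));
  // foldArgs    == {1, 2, 4, 5}    -> "__fold_args"
  // computeArgs == {4, 0, 3, 4, 5} -> "__compute_args"
  return 0;
}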
@@ -604,6 +636,9 @@ void CST::runOnOperation() {
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
                     foldArgs);
 
+  for (auto id : foldIds) {
+    computeArgs.insert(computeArgs.end(), id);
+  }
   computeArgs.insert(computeArgs.begin(), computeArgs.size());
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);