@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
538538 let hasCanonicalizer = 1;
539539}
540540
541+ //===----------------------------------------------------------------------===//
542+ // Op definition for ElementwiseOp
543+ //===----------------------------------------------------------------------===//
544+ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
545+ AttrSizedOperandSegments]> {
546+ let summary = [{ Performs element-wise operation }];
547+ let description = [{
548+ The attribute `kind` describes arithmetic operation to perform. The
549+ operation kind can be unary (e.g. max), binary (e.g. add) or ternary
550+ (e.g. select).
551+
552+ By default, all indexing maps are identities. In the case of default
553+ indexing map, all input and output shapes must match. The number of dims in
554+ each of the identity maps is equal to the rank of the output type.
555+
556+ Affine-maps for operands and result are required to be provided by the user
557+ when a transpose and/or broadcast is needed on any operand. When a map is not
558+ provided, default identity maps are inferred for each operand.
559+
560+ Iterator-types are always all `parallel`.
561+ Iterator-types are needed for constructing the underlying structured op.
562+
563+ The number of dims of the iterator-types are inferred from the rank of
564+ the result type.
565+
566+ Example:
567+
568+ Defining a unary linalg.elemwise with default indexing-map:
569+ ```mlir
570+ %exp = linalg.elemwise
571+ kind=#linalg.elemwise_kind<exp>
572+ ins(%x : tensor<4x16x8xf32>)
573+ outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
574+ ```
575+
576+ Defining a binary linalg.elemwise with user-defined indexing-map:
577+ ```mlir
578+ %add = linalg.elemwise
579+ kind=#linalg.elemwise_kind<add>
580+ indexing_maps = [#transpose, #broadcast, #identity]
581+ ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
582+ outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
583+ ```
584+ }];
585+
586+ let arguments = (ins
587+ Variadic<AnyType>:$inputs,
588+ Variadic<AnyShaped>:$outputs,
589+ ElementwiseKindAttr:$kind,
590+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
591+ );
592+
593+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
594+ let regions = (region AnyRegion:$region);
595+ let skipDefaultBuilders = 1;
596+
597+ let builders = [
598+ OpBuilder<
599+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
600+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
601+ [{
602+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
603+ attributes, ElementwiseOp::getRegionBuilder());
604+ }]>
605+ ];
606+
607+ let hasCustomAssemblyFormat = 1;
608+ let hasFolder = 1;
609+ let hasVerifier = 1;
610+
611+ let extraClassDeclaration = structuredOpsBaseDecls # [{
612+ /// Get the arity enum corresponding to the kind of op, e.g. if arg is
613+ /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
614+ static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
615+
616+ /// Both user-specified and default indexing map will always depend on
617+ /// the current Op instance.
618+ static bool hasDynamicIndexingMaps() { return true; }
619+
620+ /// Implements the block region builder for the elementwiseOp. This is
621+ /// called by the 'fillStructuredOpRegion'.
622+ static void regionBuilder(ImplicitLocOpBuilder &b,
623+ Block &block, ArrayRef<NamedAttribute> attrs);
624+
625+ static std::function<void(ImplicitLocOpBuilder &,
626+ Block &, ArrayRef<NamedAttribute>)>
627+ getRegionBuilder() {
628+ return regionBuilder;
629+ }
630+
631+ /// Returns rank of the result tensor/memref. Useful for knowing
632+ /// the dimensionality of the iteration space when others means
633+ /// are not possible e.g. absence of user-provided indexing map.
634+ unsigned getResultRank() {
635+ Value output = getDpsInitOperand(0)->get();
636+ ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
637+ return shapedType.getRank();
638+ }
639+
640+ /// Returns N 'parallel' iterator types where N is rank of result.
641+ SmallVector<utils::IteratorType> getIteratorTypesArray();
642+
643+ /// The default indexing maps are identities.
644+ /// There will be N+1 such maps, where N is the arity of the Op.
645+ static SmallVector<AffineMap>
646+ getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
647+ MLIRContext *context);
648+
649+ /// Destination passing style interface method.
650+ ::mlir::MutableOperandRange getDpsInitsMutable() {
651+ return getOutputsMutable();
652+ }
653+
654+ // Generic methods.
655+ std::string getLibraryCallName() {
656+ return generateLibraryCallName(getOperation());
657+ }
658+ }];
659+ }
660+
541661//===----------------------------------------------------------------------===//
542662// Op definition for MatmulOp
543663//===----------------------------------------------------------------------===//
0 commit comments