Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions circle-mlir/circle-mlir/lib/dialect/src/CircleDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(type))
return failure();
auto fnType = type.dyn_cast<FunctionType>();
auto fnType = mlir::dyn_cast<FunctionType>(type);
if (!fnType)
{
parser.emitError(loc, "expected function type");
Expand Down Expand Up @@ -134,7 +134,7 @@ bool VerifyOperandsHaveSameShapesOrBroadcastableShape(Operation *op, ArrayRef<un

for (unsigned index : indices)
{
ShapedType shaped_type = op->getOperand(index).getType().dyn_cast<ShapedType>();
ShapedType shaped_type = mlir::dyn_cast<ShapedType>(op->getOperand(index).getType());
if (!shaped_type || !shaped_type.hasRank())
{
// Marks that we have an unknown rank input.
Expand Down Expand Up @@ -212,8 +212,8 @@ bool EqualsZero(Value value)
return false;
}

Type element_type = value.getType().cast<ShapedType>().getElementType();
if (element_type.isa<FloatType>())
Type element_type = mlir::cast<ShapedType>(value.getType()).getElementType();
if (mlir::isa<FloatType>(element_type))
{
return constant.getSplatValue<APFloat>().isZero();
}
Expand Down Expand Up @@ -315,7 +315,7 @@ bool ExtractConstantValues(mlir::Value &input, std::vector<int64_t> &values)

void CIRDialect::printType(Type type, DialectAsmPrinter &os) const
{
if (type.isa<ControlType>())
if (mlir::isa<ControlType>(type))
{
os << "control";
return;
Expand Down Expand Up @@ -357,7 +357,7 @@ namespace
// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t)
{
if (auto shaped_type = t.dyn_cast_or_null<ShapedType>())
if (auto shaped_type = mlir::dyn_cast_or_null<ShapedType>(t))
{
return shaped_type.getElementType().isF32();
}
Expand All @@ -367,7 +367,7 @@ inline bool IsF32ShapedType(Type t)
// Returns true if it is a shaped type of i64 elements.
inline bool IsI64ShapedType(Type t)
{
if (auto shaped_type = t.dyn_cast_or_null<ShapedType>())
if (auto shaped_type = mlir::dyn_cast_or_null<ShapedType>(t))
{
return shaped_type.getElementType().isInteger(64);
}
Expand All @@ -391,11 +391,11 @@ namespace

bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type)
{
auto input_shaped_type = input_type.dyn_cast_or_null<ShapedType>();
auto input_shaped_type = mlir::dyn_cast_or_null<ShapedType>(input_type);
if (!input_shaped_type || !input_shaped_type.hasStaticShape())
return false;

auto output_shaped_type = output_type.dyn_cast_or_null<ShapedType>();
auto output_shaped_type = mlir::dyn_cast_or_null<ShapedType>(output_type);
if (!output_shaped_type || !output_shaped_type.hasStaticShape())
return false;

Expand Down Expand Up @@ -500,13 +500,13 @@ Operation *CIRDialect::materializeConstant(OpBuilder &builder, Attribute value,
{
// If this is a constant bytes attribute or the result type doesn't match the
// attribute type, then generate a tfl.pseudo_const.
if (value.isa<ConstBytesAttr>() ||
(value.isa<ElementsAttr>() && value.cast<ElementsAttr>().getType() != type))
return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
if (mlir::isa<ConstBytesAttr>(value) ||
(mlir::isa<ElementsAttr>(value) && mlir::cast<ElementsAttr>(value).getType() != type))
return builder.create<ConstOp>(loc, type, mlir::cast<ElementsAttr>(value));
if (ConstOp::isBuildableWith(value, type))
return builder.create<ConstOp>(loc, type, value.cast<ElementsAttr>());
return builder.create<ConstOp>(loc, type, mlir::cast<ElementsAttr>(value));
if (NoValueOp::isBuildableWith(value, type))
return builder.create<NoValueOp>(loc, type, value.cast<UnitAttr>());
return builder.create<NoValueOp>(loc, type, mlir::cast<UnitAttr>(value));
return nullptr;
}

Expand Down
10 changes: 5 additions & 5 deletions circle-mlir/circle-mlir/lib/dialect/src/ConstFold.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, Attribute oper
operand2.dyn_cast_or_null<DenseElementsAttr>())
{
return ConstFoldBinaryOpDenseDense<AttrElementT, ElementValueT>(
result_type, operand1.cast<DenseElementsAttr>(), operand2.cast<DenseElementsAttr>(),
result_type, mlir::cast<DenseElementsAttr>(operand1), mlir::cast<DenseElementsAttr>(operand2),
calculate);
}

Expand All @@ -169,13 +169,13 @@ Attribute ConstFoldBinaryOp(Type result_type, ArrayRef<Attribute> operands,
{
// Note: All types are wrapped in tensor types in Circle. E.g., f32 is
// represented as tensor<f32>. So we are only handling tensor types here.
auto type = result_type.dyn_cast<ShapedType>();
auto type = mlir::dyn_cast<ShapedType>(result_type);
if (!type)
return {};

auto elemType = type.getElementType();

if (elemType.isa<FloatType>())
if (mlir::isa<FloatType>(elemType))
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1], float_calculate);

if (elemType.isSignlessInteger())
Expand All @@ -191,12 +191,12 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
llvm::function_ref<APFloat(APFloat)> calculate)
{
assert(IsF32ShapedType(result_type));
auto result_shape_type = result_type.cast<ShapedType>();
auto result_shape_type = mlir::cast<ShapedType>(result_type);

if (!result_shape_type.hasStaticShape())
return {};

if (auto dense_elements = operand.dyn_cast_or_null<DenseElementsAttr>())
if (auto dense_elements = mlir::dyn_cast_or_null<DenseElementsAttr>(operand))
{
SmallVector<APFloat, 16> new_values;
const int num_elements = result_shape_type.getNumElements();
Expand Down
6 changes: 3 additions & 3 deletions circle-mlir/circle-mlir/lib/dialect/src/NameUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::string GetNameFromLoc(Location loc)
{
Location curr_loc = locs.pop_back_val();

if (auto name_loc = curr_loc.dyn_cast<NameLoc>())
if (auto name_loc = mlir::dyn_cast<NameLoc>(curr_loc))
{
// Add name in NameLoc. For NameLoc we also account for names due to ops
// in functions where the op's name is first.
Expand All @@ -54,13 +54,13 @@ std::string GetNameFromLoc(Location loc)
}
continue;
}
else if (auto call_loc = curr_loc.dyn_cast<CallSiteLoc>())
else if (auto call_loc = mlir::dyn_cast<CallSiteLoc>(curr_loc))
{
// Use location of the Callee to generate the name.
locs.push_back(call_loc.getCallee());
continue;
}
else if (auto fused_loc = curr_loc.dyn_cast<FusedLoc>())
else if (auto fused_loc = mlir::dyn_cast<FusedLoc>(curr_loc))
{
// Push all locations in FusedLoc in reverse order, so locations are
// visited based on order in FusedLoc.
Expand Down
12 changes: 6 additions & 6 deletions circle-mlir/circle-mlir/lib/dialect/src/ops/CastOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,23 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor)
}

// For now, only supports cast for the integer/float input type.
auto elements_attr = operands[0].dyn_cast_or_null<mlir::DenseElementsAttr>();
auto elements_attr = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(operands[0]);
if (!elements_attr)
{
return nullptr;
}

auto result_element_type = getType().cast<ShapedType>().getElementType();
auto operand_element_type = getInput().getType().cast<ShapedType>().getElementType();
auto operand_int_type = operand_element_type.dyn_cast<IntegerType>();
auto result_element_type = mlir::cast<ShapedType>(getType()).getElementType();
auto operand_element_type = mlir::cast<ShapedType>(getInput().getType()).getElementType();
auto operand_int_type = mlir::dyn_cast<IntegerType>(operand_element_type);
if (!result_element_type || !operand_element_type)
{
return nullptr;
}

if (mlir::isa<mlir::IntegerType>(result_element_type))
{
auto result_int_type = result_element_type.dyn_cast<IntegerType>();
auto result_int_type = mlir::dyn_cast<IntegerType>(result_element_type);
const int output_bitwidth = result_int_type.getWidth();
// check for INT64 <--> INT32
if (operand_int_type)
Expand Down Expand Up @@ -97,7 +97,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor)
else if (mlir::isa<mlir::FloatType>(result_element_type))
{
// Refer to https://llvm.org/doxygen/classllvm_1_1APFloat.html
auto result_float_type = result_element_type.dyn_cast<FloatType>();
auto result_float_type = mlir::dyn_cast<FloatType>(result_element_type);
// To get the correct semantics of floating point from the type of this CastOp
const llvm::fltSemantics &semantics = result_float_type.getFloatSemantics();
auto cast = [&](const llvm::APInt &value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ bool IsConcatenationOpConstFoldable(ConcatenationOp op, ArrayRef<Attribute> oper
return false;

return llvm::all_of(
operands, [](Attribute operand) { return operand && operand.isa<DenseElementsAttr>(); });
operands, [](Attribute operand) { return operand && mlir::isa<DenseElementsAttr>(operand); });
}

DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
Expand Down
4 changes: 2 additions & 2 deletions circle-mlir/circle-mlir/lib/dialect/src/ops/ConstOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ bool ConstOp::isBuildableWith(Attribute value, Type type)
if (!typedAttr || typedAttr.getType() != type)
return false;
// Integer values must be signless.
if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
if (mlir::isa<IntegerType>(type) && !mlir::cast<IntegerType>(type).isSignless())
return false;
// Integer, float, and element attributes are buildable.
return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
return mlir::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}

} // namespace Circle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ LogicalResult FullyConnectedOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFold
return failure();

// Bias tensor is optional.
const bool has_bias = !(!getBias() || getBias().getType().isa<NoneType>());
const bool has_bias = !(!getBias() || mlir::isa<NoneType>(getBias().getType()));

// Get the tensors.
DenseElementsAttr input_tensor, weights_tensor, bias_tensor;
Expand Down
2 changes: 1 addition & 1 deletion circle-mlir/circle-mlir/lib/dialect/src/ops/NoValueOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ OpFoldResult NoValueOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

bool NoValueOp::isBuildableWith(Attribute value, Type type)
{
return value.isa<UnitAttr>() && type.isa<NoneType>();
return mlir::isa<UnitAttr>(value) && mlir::isa<NoneType>(type);
}

} // namespace Circle
Expand Down
23 changes: 12 additions & 11 deletions circle-mlir/circle-mlir/lib/export/src/CircleExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ static circle::TensorType GetCircleType(mlir::Type type, bool is_signed = true)
{
return circle::TensorType_FLOAT64;
}
else if (auto itype = type.dyn_cast<mlir::IntegerType>())
else if (auto itype = mlir::dyn_cast<mlir::IntegerType>(type))
{
switch (itype.getWidth())
{
Expand Down Expand Up @@ -313,7 +313,7 @@ std::optional<BufferOffset<circle::Buffer>> Translator::BuildBuffer(mlir::Value
{
// arith::ConstantOp have ElementAttr at this point due to validation of the
// Circle module.
attr = cst.getValue().cast<mlir::ElementsAttr>();
attr = mlir::cast<mlir::ElementsAttr>(cst.getValue());
}
else if (auto cst = llvm::dyn_cast<mlir::Circle::ConstOp>(inst))
{
Expand All @@ -325,7 +325,7 @@ std::optional<BufferOffset<circle::Buffer>> Translator::BuildBuffer(mlir::Value
return empty_buffer_;
}

auto type = value.getType().cast<mlir::TensorType>();
auto type = mlir::cast<mlir::TensorType>(value.getType());
circle::TensorType circle_element_type = GetCircleType(type.getElementType());

BYTES data;
Expand Down Expand Up @@ -383,7 +383,7 @@ std::optional<BufferOffset<circle::Tensor>> Translator::BuildTensor(
mlir::Value value, const std::string &name, unsigned buffer_idx,
const std::optional<BufferOffset<circle::QuantizationParameters>> &quant_parameters)
{
auto type = value.getType().cast<mlir::TensorType>();
auto type = mlir::cast<mlir::TensorType>(value.getType());

// Circle requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping
Expand Down Expand Up @@ -413,8 +413,9 @@ std::optional<BufferOffset<circle::Tensor>> Translator::BuildTensor(
// Const op can have a result of dynamic shaped type (e.g. due to constant
// folding), but we can still derive the shape of a constant tensor for
// its attribute type.
auto tensor_attr = inst->getAttr("value").cast<mlir::TypedAttr>();
llvm::ArrayRef<int64_t> shape_ref = tensor_attr.getType().cast<mlir::TensorType>().getShape();
auto tensor_attr = mlir::cast<mlir::TypedAttr>(inst->getAttr("value"));
llvm::ArrayRef<int64_t> shape_ref =
mlir::cast<mlir::TensorType>(tensor_attr.getType()).getShape();
if (mlir::failed(check_shape(shape_ref)))
return std::nullopt;

Expand Down Expand Up @@ -495,7 +496,7 @@ BufferOffset<circle::Operator> Translator::BuildCustomOperator(Operation *inst,
const std::vector<int32_t> &results)
{
const std::string attrs =
op.getCustomOption().cast<mlir::Circle::ConstBytesAttr>().getValue().str();
mlir::cast<mlir::Circle::ConstBytesAttr>(op.getCustomOption()).getValue().str();
std::vector<uint8_t> custom_option_vector(attrs.size());
memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
auto opcode_index = GetOpcodeIndex(op.getCustomCode().str(), circle::BuiltinOperator_CUSTOM);
Expand Down Expand Up @@ -616,7 +617,7 @@ void Translator::InitializeNamesFromAttribute(mlir::func::FuncOp fn, bool *has_i
assert(inputNames.size() == fn.getArguments().size());
for (const auto &it : llvm::enumerate(fn.getArguments()))
{
auto strattr = inputNames[it.index()].cast<mlir::StringAttr>();
auto strattr = mlir::cast<mlir::StringAttr>(inputNames[it.index()]);
name_mapper_.InitOpName(it.value(), strattr);
}
*has_input_attr = true;
Expand All @@ -628,7 +629,7 @@ void Translator::InitializeNamesFromAttribute(mlir::func::FuncOp fn, bool *has_i
assert(outputNames.size() == term->getOperands().size());
for (const auto &it : llvm::enumerate(term->getOperands()))
{
auto strattr = outputNames[it.index()].cast<mlir::StringAttr>();
auto strattr = mlir::cast<mlir::StringAttr>(outputNames[it.index()]);
name_mapper_.InitOpName(it.value(), strattr);
}
*has_input_attr = true;
Expand All @@ -651,7 +652,7 @@ Translator::BuildSubGraph(const std::string &name, mlir::Region *region, const i
auto build_tensor_and_buffer = [&](mlir::Value value, const int subgraph_index,
const std::string &tensor_name) {
// NoneType represents optional and may be skipped here.
if (value.getType().isa<mlir::NoneType>())
if (mlir::isa<mlir::NoneType>(value.getType()))
{
return true;
}
Expand Down Expand Up @@ -754,7 +755,7 @@ Translator::BuildSubGraph(const std::string &name, mlir::Region *region, const i
operands.reserve(real_inst->getNumOperands());
for (auto operand : real_inst->getOperands())
{
if (operand.getType().isa<mlir::NoneType>())
if (mlir::isa<mlir::NoneType>(operand.getType()))
operands.push_back(kCircleOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));
Expand Down
8 changes: 4 additions & 4 deletions circle-mlir/circle-mlir/lib/export/src/OpOrArgNameMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ bool OpOrArgNameMapper::IsUnique(llvm::StringRef name) { return true; }

std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val)
{
if (auto *op = op_or_val.dyn_cast<mlir::Operation *>())
if (auto *op = mlir::dyn_cast<mlir::Operation *>(op_or_val))
{
// NOTE stop for debug version to find out if there is any Op for this case
assert(false);
Expand All @@ -131,14 +131,14 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val)
// generated using the op type.
return std::string(op->getName().getStringRef());
}
auto val = op_or_val.dyn_cast<mlir::Value>();
auto val = mlir::dyn_cast<mlir::Value>(op_or_val);
auto name_from_loc = mlir::GetNameFromLoc(val.getLoc());
if (!name_from_loc.empty())
return name_from_loc;
// If the location is none of the expected types, then simply use name
// generated using the op type. Follow TF convention and append the result
// index unless 0.
if (auto result = val.dyn_cast<mlir::OpResult>())
if (auto result = mlir::dyn_cast<mlir::OpResult>(val))
{
auto name_str = result.getOwner()->getName().getStringRef().str();
auto value_op = val.getDefiningOp();
Expand All @@ -154,7 +154,7 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val)
return std::string(name_str);
}
// Use the ASM syntax for BlockArgument
if (auto arg = val.dyn_cast<mlir::BlockArgument>())
if (auto arg = mlir::dyn_cast<mlir::BlockArgument>(val))
{
return "arg" + std::to_string(arg.getArgNumber());
}
Expand Down
Loading