@@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
   });
 }
 
-static StableIValue from_ivalue(
-    const c10::TypePtr& type,
-    const c10::IValue& ivalue) {
-  switch (type->kind()) {
-    case c10::TypeKind::TensorType: {
-      AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
-          std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
-      return torch::stable::detail::from(ath);
-    }
-    case c10::TypeKind::IntType: {
-      return torch::stable::detail::from(ivalue.toInt());
-    }
-    case c10::TypeKind::FloatType: {
-      return torch::stable::detail::from(ivalue.toDouble());
-    }
-    case c10::TypeKind::BoolType: {
-      return torch::stable::detail::from(ivalue.toBool());
-    }
-    case c10::TypeKind::ScalarTypeType: {
-      return torch::stable::detail::from(ivalue.toScalarType());
-    }
-    case c10::TypeKind::DeviceObjType: {
-      return torch::stable::detail::from(ivalue.toDevice());
-    }
-    case c10::TypeKind::LayoutType: {
-      return torch::stable::detail::from(ivalue.toLayout());
-    }
-    case c10::TypeKind::MemoryFormatType: {
-      return torch::stable::detail::from(ivalue.toMemoryFormat());
-    }
-    case c10::TypeKind::OptionalType: {
-      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
-
-      // ideally, if we had the C++ type corresponding to inner_type, which we
-      // will denote as inner_type::t (does not actually exist), we would be
-      // able to follow the patterned semantic of every other case here in one
-      // line:
-      //
-      // return
-      // torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT());
-      //
-      // BUT we do NOT have that type inner_type::t readily available, so we
-      // will manually unwrap and recursively call. This implementation MUST
-      // be kept in sync with torch::stable::detail::from<std::optional<T>>
-      // function in torch/csrc/stable/stableivalue_conversions.h
-      if (ivalue.isNone()) {
-        return torch::stable::detail::from(std::nullopt);
-      }
-      StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
-      return torch::stable::detail::from(sivp);
-    }
-    default: {
-      TORCH_CHECK(
-          false,
-          "Not yet supported conversion from IValue to StableIValue for schema type: ",
-          type->str());
-    }
-  }
-}
-
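The comment in the OptionalType branch pins a cross-file invariant, so a sketch may help. Below is one plausible shape for the `std::optional` packing that branch must mirror; the `StableIValue` alias and the `pack*` helpers here are illustrative stand-ins, not the real `torch::stable::detail` API:

```cpp
#include <cstdint>
#include <cstring>
#include <optional>

using StableIValue = uint64_t; // stand-in: one trivially-copyable 64-bit slot

// Stand-in for torch::stable::detail::from on plain values: bit-pack the
// value into the slot (illustrative; the real conversions are more careful).
template <typename T>
StableIValue pack(T val) {
  static_assert(sizeof(T) <= sizeof(StableIValue), "must fit in one slot");
  StableIValue out = 0;
  std::memcpy(&out, &val, sizeof(val));
  return out;
}

inline StableIValue pack_nullopt() {
  return pack<std::nullptr_t>(nullptr); // sentinel: a null inner pointer
}

// The shape the OptionalType branch above must stay in sync with: None maps
// to the sentinel; anything else is boxed behind a heap-allocated inner slot.
template <typename T>
StableIValue pack_optional(const std::optional<T>& val) {
  if (!val.has_value()) {
    return pack_nullopt();
  }
  return pack(new StableIValue(pack(val.value())));
}
```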
-static c10::IValue to_ivalue(
-    const c10::TypePtr& type,
-    const StableIValue stable_ivalue) {
-  switch (type->kind()) {
-    case c10::TypeKind::TensorType: {
-      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-          torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
-      return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
-          ret_raiiath.get())));
-    }
-    case c10::TypeKind::IntType: {
-      return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
-    }
-    case c10::TypeKind::FloatType: {
-      return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
-    }
-    case c10::TypeKind::BoolType: {
-      return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
-    }
-    case c10::TypeKind::ScalarTypeType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
-    }
-    case c10::TypeKind::DeviceObjType: {
-      return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
-    }
-    case c10::TypeKind::LayoutType: {
-      return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
-    }
-    case c10::TypeKind::MemoryFormatType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
-    }
-    case c10::TypeKind::OptionalType: {
-      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
-
-      // ideally, if we had the C++ type corresponding to inner_type, which we
-      // will denote as inner_type::t (does not actually exist), we would be
-      // able to follow the patterned semantic of every other case here in one
-      // line:
-      //
-      // return
-      // c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
-      //
-      // BUT we do NOT have that type inner_type::t readily available, so we
-      // will manually unwrap and recursively call. This implementation MUST
-      // be kept in sync with the torch::stable::detail::to<T> function in
-      // torch/csrc/stable/stableivalue_conversions.h
-      if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
-        return c10::IValue();
-      }
-      auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
-      auto ival = to_ivalue(inner_type, *sivp);
-      delete sivp;
-      return ival;
-    }
-    default: {
-      TORCH_CHECK(
-          false,
-          "Not yet supported conversion from StableIValue to IValue for schema type: ",
-          type->str());
-    }
-  }
-}
-
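The inverse direction, using the same stand-in helpers as the sketch after `from_ivalue` (again illustrative, not the real `torch::stable::detail::to`):

```cpp
// Stand-in for torch::stable::detail::to on plain values: the bitwise inverse
// of pack() from the previous sketch.
template <typename T>
T unpack(StableIValue val) {
  static_assert(sizeof(T) <= sizeof(StableIValue), "must fit in one slot");
  T out;
  std::memcpy(&out, &val, sizeof(out));
  return out;
}

// Inverse of pack_optional: the sentinel decodes to nullopt; otherwise unwrap
// the heap-allocated inner slot and free it, just as to_ivalue does above.
template <typename T>
std::optional<T> unpack_optional(StableIValue val) {
  if (val == pack_nullopt()) {
    return std::nullopt;
  }
  StableIValue* sivp = unpack<StableIValue*>(val);
  std::optional<T> out = unpack<T>(*sivp);
  delete sivp;
  return out;
}
```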
-class StableIValueBoxedKernel : public c10::OperatorKernel {
- public:
-  StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
-      : fn_(fn) {}
-
-  void operator()(
-      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
-      torch::jit::Stack* stack) {
-    const auto& schema = op.schema();
-    const auto num_returns = schema.returns().size();
-    const auto num_arguments = schema.arguments().size();
-
-    auto ministack =
-        std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
-
-    for (const auto idx : c10::irange(num_arguments)) {
-      const auto ministack_idx = num_arguments - idx - 1;
-      const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
-      ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
-    }
-
-    // the boxed function takes the stack of StableIValues, converts them to
-    // its schema types, runs the kernel, and writes the results back onto the
-    // StableIValue stack
-    fn_(ministack.get(), num_arguments, num_returns);
-
-    // read the outputs from the ministack and wrap each StableIValue back
-    // into an IValue on the jit stack
-    for (size_t idx = 0; idx < num_returns; idx++) {
-      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
-    }
-  }
-
- private:
-  void (*fn_)(StableIValue*, uint64_t, uint64_t);
-};
-
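To make the `fn` contract concrete: a hypothetical library-side kernel for an invented schema `my_ops::identity(Tensor t) -> Tensor` could look like the following. The functor above supplies arguments in `stack[0..num_args-1]` and reads returns starting at `stack[0]`:

```cpp
// Hypothetical boxed kernel matching the fn signature StableIValueBoxedKernel
// expects. Arguments arrive in stack[0..num_args-1]; results must be written
// back starting at stack[0]. Identity simply hands the tensor handle through.
void boxed_identity(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_returns) {
  AtenTensorHandle t = torch::stable::detail::to<AtenTensorHandle>(stack[0]);
  // a real kernel would compute a new tensor here using only stable C APIs
  stack[0] = torch::stable::detail::from(t);
}
```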
 AOTITorchError aoti_torch_library_init_impl(
     const char* ns,
     const char* k,
@@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
   });
 }
 
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
-    TorchLibraryHandle self,
-    const char* name,
-    void (*fn)(StableIValue*, uint64_t, uint64_t)) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    reinterpret_cast<torch::Library*>(self)->impl(
-        name,
-        torch::CppFunction::makeFromBoxedFunctor(
-            std::make_unique<StableIValueBoxedKernel>(fn)));
-  });
-}
-
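Putting the two pieces together, registration through the shim might look roughly like this (the init call's trailing arguments are elided, matching its truncated signature above):

```cpp
// Sketch: route the hypothetical boxed_identity kernel through the shim.
// Assumes `lib` was produced by aoti_torch_library_init_impl for the
// "my_ops" namespace and a dispatch key such as "CPU".
TorchLibraryHandle lib = nullptr;
// aoti_torch_library_init_impl("my_ops", "CPU", /* ... */, &lib);
AOTITorchError err = aoti_torch_library_impl(lib, "identity", &boxed_identity);
```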
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
@@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
       { delete reinterpret_cast<torch::Library*>(tlh); });
 }
 
-AOTITorchError aoti_torch_call_dispatcher(
-    const char* opName,
-    const char* overloadName,
-    StableIValue* stack) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    const auto op =
-        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
-    const auto& schema = op.schema();
-    const auto num_returns = schema.returns().size();
-    const auto num_arguments = schema.arguments().size();
-
-    torch::jit::Stack ivalue_stack;
-    // we will only ever need max(num_arguments, num_returns) slots
-    ivalue_stack.reserve(std::max(num_arguments, num_returns));
-
-    // convert the StableIValue stack into a c10::IValue stack
-    for (const auto idx : c10::irange(num_arguments)) {
-      auto stable_ivalue = stack[idx];
-      auto arg_type = schema.arguments()[idx].type();
-      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
-    }
-
-    op.callBoxed(ivalue_stack);
-
-    // there should now be num_returns IValues on the stack; convert them back
-    // to StableIValues and repopulate the user's stack. We pop from the back,
-    // so the return type must be indexed by stack_idx, not idx.
-    for (const auto idx : c10::irange(num_returns)) {
-      const auto stack_idx = num_returns - idx - 1;
-      const c10::TypePtr& ret_type = schema.returns()[stack_idx].type();
-      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
-    }
-  });
-}
-
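From the caller's side, the in-place stack convention would look like this; `aten::abs` (one argument, one return) keeps the sketch minimal:

```cpp
// Sketch: call aten::abs(Tensor self) -> Tensor through the shim. The caller
// packs the arguments into the stack; on success, the same leading slots hold
// the returns. Assumes `self` is a valid AtenTensorHandle owned by the caller.
StableIValue stack[1];
stack[0] = torch::stable::detail::from(self);
AOTITorchError err = aoti_torch_call_dispatcher("aten::abs", "", stack);
AtenTensorHandle result = torch::stable::detail::to<AtenTensorHandle>(stack[0]);
```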
 AOTITorchError aoti_torch_create_device_guard(
     int32_t device_index,
     DeviceGuardHandle* ret_guard // returns new reference