Skip to content

Commit c0bbda3

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Move static from_ivalue/to_ivalue to new shim_common.cpp (pytorch#166373)
Move `from_ivalue` and `to_ivalue` and their dependents `StableIValueBoxedKernel`, `aoti_torch_library_impl` `aoti_torch_call_dispatcher` into new (non-aoti shim_common.cpp) This is in prep for the above PRs where I add v2s (`torch_call_dispatcher` and `torch_library_impl`) that are versioning aware Pull Request resolved: pytorch#166373 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#164356
1 parent fefb546 commit c0bbda3

File tree

3 files changed

+219
-209
lines changed

3 files changed

+219
-209
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ inductor_core_resources = [
482482
"torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
483483
"torch/csrc/inductor/inductor_ops.cpp",
484484
"torch/csrc/jit/serialization/pickle.cpp",
485+
"torch/csrc/shim_common.cpp",
485486
]
486487

487488
libtorch_core_sources = sorted(

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 0 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
14061406
});
14071407
}
14081408

1409-
static StableIValue from_ivalue(
1410-
const c10::TypePtr& type,
1411-
const c10::IValue& ivalue) {
1412-
switch (type->kind()) {
1413-
case c10::TypeKind::TensorType: {
1414-
AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
1415-
std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
1416-
return torch::stable::detail::from(ath);
1417-
}
1418-
case c10::TypeKind::IntType: {
1419-
return torch::stable::detail::from(ivalue.toInt());
1420-
}
1421-
case c10::TypeKind::FloatType: {
1422-
return torch::stable::detail::from(ivalue.toDouble());
1423-
}
1424-
case c10::TypeKind::BoolType: {
1425-
return torch::stable::detail::from(ivalue.toBool());
1426-
}
1427-
case c10::TypeKind::ScalarTypeType: {
1428-
return torch::stable::detail::from(ivalue.toScalarType());
1429-
}
1430-
case c10::TypeKind::DeviceObjType: {
1431-
return torch::stable::detail::from(ivalue.toDevice());
1432-
}
1433-
case c10::TypeKind::LayoutType: {
1434-
return torch::stable::detail::from(ivalue.toLayout());
1435-
}
1436-
case c10::TypeKind::MemoryFormatType: {
1437-
return torch::stable::detail::from(ivalue.toMemoryFormat());
1438-
}
1439-
case c10::TypeKind::OptionalType: {
1440-
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
1441-
1442-
// ideally, if we had the C++ type corresponding to inner_type, which we
1443-
// will denote as inner_type::t (does not actually exist), we would be
1444-
// able to follow the patterned semantic of every other case here in one
1445-
// line:
1446-
//
1447-
// return
1448-
// torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
1449-
//
1450-
// BUT we do NOT have that type inner_type::t readily available, so we
1451-
// will manually unwrap and recursively call. This implementation MUST
1452-
// be kept in sync with torch::stable::detail::from<std::optional<T>>
1453-
// function in torch/csrc/stable/stableivalue_conversions.h
1454-
if (ivalue.isNone()) {
1455-
return torch::stable::detail::from(std::nullopt);
1456-
}
1457-
StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
1458-
return torch::stable::detail::from(sivp);
1459-
}
1460-
default: {
1461-
TORCH_CHECK(
1462-
false,
1463-
"Not yet supported conversion from IValue to StableIValue for schema type: ",
1464-
type->str());
1465-
}
1466-
}
1467-
}
1468-
1469-
static c10::IValue to_ivalue(
1470-
const c10::TypePtr& type,
1471-
const StableIValue stable_ivalue) {
1472-
switch (type->kind()) {
1473-
case c10::TypeKind::TensorType: {
1474-
auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
1475-
torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
1476-
return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
1477-
ret_raiiath.get())));
1478-
}
1479-
case c10::TypeKind::IntType: {
1480-
return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
1481-
}
1482-
case c10::TypeKind::FloatType: {
1483-
return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
1484-
}
1485-
case c10::TypeKind::BoolType: {
1486-
return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
1487-
}
1488-
case c10::TypeKind::ScalarTypeType: {
1489-
return c10::IValue(
1490-
torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
1491-
}
1492-
case c10::TypeKind::DeviceObjType: {
1493-
return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
1494-
}
1495-
case c10::TypeKind::LayoutType: {
1496-
return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
1497-
}
1498-
case c10::TypeKind::MemoryFormatType: {
1499-
return c10::IValue(
1500-
torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
1501-
}
1502-
case c10::TypeKind::OptionalType: {
1503-
auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
1504-
1505-
// ideally, if we had the C++ type corresponding to inner_type, which we
1506-
// will denote as inner_type::t (does not actually exist), we would be
1507-
// able to follow the patterned semantic of every other case here in one
1508-
// line:
1509-
//
1510-
// return
1511-
// c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
1512-
//
1513-
// BUT we do NOT have that type inner_type::t readily available, so we
1514-
// will manually unwrap and recursively call. This implementation MUST
1515-
// be kept in sync with the torch::stable::detail::to<T> function in
1516-
// torch/csrc/stable/stableivalue_conversions.h
1517-
if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
1518-
return c10::IValue();
1519-
}
1520-
auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
1521-
auto ival = to_ivalue(inner_type, *sivp);
1522-
delete sivp;
1523-
return ival;
1524-
}
1525-
default: {
1526-
TORCH_CHECK(
1527-
false,
1528-
"Not yet supported conversion from StableIValue to IValue for schema type: ",
1529-
type->str());
1530-
}
1531-
}
1532-
}
1533-
1534-
class StableIValueBoxedKernel : public c10::OperatorKernel {
1535-
public:
1536-
StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
1537-
: fn_(fn) {}
1538-
1539-
void operator()(
1540-
const c10::OperatorHandle& op,
1541-
c10::DispatchKeySet keyset,
1542-
torch::jit::Stack* stack) {
1543-
const auto& schema = op.schema();
1544-
const auto num_returns = schema.returns().size();
1545-
const auto num_arguments = schema.arguments().size();
1546-
1547-
auto ministack =
1548-
std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
1549-
1550-
for (const auto idx : c10::irange(num_arguments)) {
1551-
const auto ministack_idx = num_arguments - idx - 1;
1552-
const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
1553-
ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
1554-
}
1555-
1556-
// boxed function is going to take a stack of StableIValues, cast them to
1557-
// our schema values, and run the function and modify the StableIValue stack
1558-
fn_(ministack.get(), num_arguments, num_returns);
1559-
1560-
// read the output from the end of the stack and wrap that back into
1561-
// IValue from StableIValue
1562-
for (size_t idx = 0; idx < num_returns; idx++) {
1563-
const c10::TypePtr& ret_type = schema.returns()[idx].type();
1564-
torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
1565-
}
1566-
}
1567-
1568-
private:
1569-
void (*fn_)(StableIValue*, uint64_t, uint64_t);
1570-
};
1571-
15721409
AOTITorchError aoti_torch_library_init_impl(
15731410
const char* ns,
15741411
const char* k,
@@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
16181455
});
16191456
}
16201457

1621-
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
1622-
TorchLibraryHandle self,
1623-
const char* name,
1624-
void (*fn)(StableIValue*, uint64_t, uint64_t)) {
1625-
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
1626-
reinterpret_cast<torch::Library*>(self)->impl(
1627-
name,
1628-
torch::CppFunction::makeFromBoxedFunctor(
1629-
std::make_unique<StableIValueBoxedKernel>(fn)));
1630-
});
1631-
}
1632-
16331458
AOTI_TORCH_EXPORT AOTITorchError
16341459
aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
16351460
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
@@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
16421467
{ delete reinterpret_cast<torch::Library*>(tlh); });
16431468
}
16441469

1645-
AOTITorchError aoti_torch_call_dispatcher(
1646-
const char* opName,
1647-
const char* overloadName,
1648-
StableIValue* stack) {
1649-
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
1650-
const auto op =
1651-
c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
1652-
const auto& schema = op.schema();
1653-
const auto num_returns = schema.returns().size();
1654-
const auto num_arguments = schema.arguments().size();
1655-
1656-
torch::jit::Stack ivalue_stack;
1657-
// we will only need max(num_args, num_returns)
1658-
ivalue_stack.reserve(std::max(num_arguments, num_returns));
1659-
1660-
// convert StableIValue stack to c10::IValue stack
1661-
for (const auto idx : c10::irange(num_arguments)) {
1662-
auto stable_ivalue = stack[idx];
1663-
auto arg_type = schema.arguments()[idx].type();
1664-
torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
1665-
}
1666-
1667-
op.callBoxed(ivalue_stack);
1668-
1669-
// there should then be num_returns IValues on the stack, which
1670-
// we will convert to StableIValue and repopulate user input stack
1671-
for (const auto idx : c10::irange(num_returns)) {
1672-
const auto stack_idx = num_returns - idx - 1;
1673-
const c10::TypePtr& ret_type = schema.returns()[idx].type();
1674-
stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
1675-
}
1676-
});
1677-
}
1678-
16791470
AOTITorchError aoti_torch_create_device_guard(
16801471
int32_t device_index,
16811472
DeviceGuardHandle* ret_guard // returns new reference

0 commit comments

Comments
 (0)