Skip to content

Commit 3420f46

Browse files
thunderseethecopybara-github
authored andcommitted
Generate bindings to implementations of generic traits.
Implementations of generic traits will receive bindings if their implementation is fully monomorphic. Traits with constant parameters still will not receive bindings. PiperOrigin-RevId: 879867569
1 parent f556175 commit 3420f46

File tree

11 files changed

+597
-114
lines changed

11 files changed

+597
-114
lines changed

cc_bindings_from_rs/generate_bindings/generate_function.rs

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use quote::quote;
2929
use rustc_hir::attrs::AttributeKind;
3030
use rustc_hir::{self as hir, def::DefKind};
3131
use rustc_middle::mir::Mutability;
32-
use rustc_middle::ty::{self, Ty, TyCtxt};
32+
use rustc_middle::ty::{self, TraitRef, Ty, TyCtxt};
3333
use rustc_span::def_id::DefId;
3434
use rustc_span::symbol::Symbol;
3535
use std::collections::BTreeSet;
@@ -697,6 +697,54 @@ fn get_function_cc_name(db: &BindingsGenerator, def_id: DefId) -> Result<Ident>
697697
.context("Error formatting function name")
698698
}
699699

700+
fn format_trait_ref_for_cc<'tcx>(
701+
db: &BindingsGenerator<'tcx>,
702+
trait_ref: &TraitRef<'tcx>,
703+
) -> Result<CcSnippet<'tcx>> {
704+
let trait_name = db
705+
.symbol_canonical_name(trait_ref.def_id)
706+
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
707+
.expect("Generated trait method for a trait with an invalid cc name");
708+
let mut trait_args = trait_ref.args[1..].iter().filter_map(|arg| arg.as_type()).peekable();
709+
let mut prereqs = CcPrerequisites::default();
710+
let tokens = if trait_args.peek().is_none() {
711+
quote! { #trait_name }
712+
} else {
713+
let arg_tokens = trait_args
714+
.map(|ty_arg| {
715+
Ok(db.format_ty_for_cc(ty_arg, TypeLocation::Other)?.into_tokens(&mut prereqs))
716+
})
717+
.collect::<Result<Vec<_>>>()?;
718+
quote! { #trait_name<#(#arg_tokens),*> }
719+
};
720+
Ok(CcSnippet { prereqs, tokens })
721+
}
722+
723+
fn format_trait_ref_for_rs<'tcx>(
724+
db: &BindingsGenerator<'tcx>,
725+
trait_ref: &TraitRef<'tcx>,
726+
) -> Result<TokenStream> {
727+
let trait_name = db
728+
.symbol_canonical_name(trait_ref.def_id)
729+
.map(|fully_qualified_name| fully_qualified_name.format_for_rs())
730+
.expect("Generated trait method for a trait with an invalid rs name");
731+
let mut trait_args = trait_ref.args[1..].iter().filter_map(|arg| arg.as_type()).peekable();
732+
if trait_args.peek().is_none() {
733+
Ok(quote! { #trait_name })
734+
} else {
735+
let arg_tokens = trait_args
736+
.map(|ty_arg| {
737+
let static_ty_arg = crate::generate_function_thunk::replace_all_regions_with_static(
738+
db.tcx(),
739+
ty_arg,
740+
);
741+
db.format_ty_for_rs(static_ty_arg)
742+
})
743+
.collect::<Result<Vec<_>>>()?;
744+
Ok(quote! { #trait_name<#(#arg_tokens),*> })
745+
}
746+
}
747+
700748
/// Implementation of `BindingsGenerator::generate_function`.
701749
pub fn generate_function<'tcx>(
702750
db: &BindingsGenerator<'tcx>,
@@ -946,15 +994,14 @@ pub fn generate_function<'tcx>(
946994
let decl_name = trait_ref
947995
.as_ref()
948996
.map(|trait_ref| {
949-
let trait_name = db
950-
.symbol_canonical_name(trait_ref.def_id)
951-
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
952-
.expect("Generated trait method for a trait with an invalid rust name");
953997
let struct_name = struct_name
954998
.as_ref()
955999
.and_then(|fully_qualified_name| fully_qualified_name.format_for_cc(db).ok())
9561000
.expect("Generated trait method for an ADT with an invalid rust name");
957-
quote! { rs_std :: impl <#struct_name, #trait_name> :: #bracketed_decl_name }
1001+
let trait_name_with_args = format_trait_ref_for_cc(db, trait_ref)
1002+
.expect("Implementation of trait containing invalid type requested. Caller should have verified type arguments were valid.")
1003+
.into_tokens(&mut prereqs);
1004+
quote! { rs_std :: impl <#struct_name, #trait_name_with_args> :: #bracketed_decl_name }
9581005
})
9591006
.or_else(|| {
9601007
struct_name.as_ref().map(|fully_qualified_name| {
@@ -993,7 +1040,8 @@ pub fn generate_function<'tcx>(
9931040
.map(|fully_qualified_name| fully_qualified_name.format_for_rs())
9941041
.expect("Generated trait method for an ADT with an invalid rust name");
9951042
let fn_name = make_rs_ident(unqualified_rust_fn_name.as_str());
996-
quote! { <#struct_name as #trait_name>::#fn_name }
1043+
let trait_name_with_args = format_trait_ref_for_rs(db, trait_ref).expect("Implementation of trait containing invalid type requested. Caller should have verified type arguments were valid.");
1044+
quote! { <#struct_name as #trait_name_with_args>::#fn_name }
9971045
})
9981046
// Inherent method
9991047
.or_else(|| struct_name.as_ref().map(|struct_name| {

cc_bindings_from_rs/generate_bindings/generate_struct_and_union.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use std::collections::{BTreeSet, HashMap, HashSet};
3838
use std::iter::once;
3939
use std::rc::Rc;
4040

41-
fn has_type_or_const_vars() -> TypeFlags {
41+
pub(crate) fn has_type_or_const_vars() -> TypeFlags {
4242
TypeFlags::HAS_TY_PARAM
4343
| TypeFlags::HAS_CT_PARAM
4444
| TypeFlags::HAS_TY_INFER

cc_bindings_from_rs/generate_bindings/lib.rs

Lines changed: 115 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use crate::generate_function::{generate_function, must_use_attr_of};
3333
use crate::generate_function_thunk::{generate_trait_thunks, TraitThunks};
3434
use crate::generate_struct_and_union::{
3535
adt_needs_bindings, cpp_enum_cpp_underlying_type, from_trait_impls_by_argument, generate_adt,
36-
generate_adt_core, generate_associated_item, scalar_value_to_string,
36+
generate_adt_core, generate_associated_item, has_type_or_const_vars, scalar_value_to_string,
3737
};
3838
use arc_anyhow::{Context, Error, Result};
3939
use code_gen_utils::{format_cc_includes, CcConstQualifier, CcInclude, NamespaceQualifier};
@@ -58,7 +58,7 @@ use rustc_abi::{AddressSpace, BackendRepr, Integer, Primitive, Scalar};
5858
use rustc_hir::def::{DefKind, Res};
5959
use rustc_middle::metadata::{ModChild, Reexport};
6060
use rustc_middle::mir::ConstValue;
61-
use rustc_middle::ty::{self, Ty, TyCtxt};
61+
use rustc_middle::ty::{self, GenericParamDefKind, Ty, TyCtxt};
6262
use rustc_span::def_id::{CrateNum, DefId, LOCAL_CRATE};
6363
use rustc_span::symbol::{sym, Symbol};
6464
use std::cmp::Ordering;
@@ -862,12 +862,13 @@ fn supported_traits(db: &BindingsGenerator<'_>) -> Rc<[DefId]> {
862862
&& crate_name.as_str() != "alloc";
863863

864864
let generics = tcx.generics_of(*trait_id);
865-
// TODO: b/259749095 - Support generics in Traits.
866-
// Traits will have a single parameter for the self type which is allowed.
867-
let no_generic_args = (generics.has_self
868-
&& generics.own_params.iter().filter(|param| param.kind.is_ty_or_const()).count()
869-
== 1)
870-
|| !generics.requires_monomorphization(tcx);
865+
// Traits do not support const generics.
866+
let no_generic_args = generics
867+
.own_params
868+
.iter()
869+
.filter(|param| matches!(param.kind, GenericParamDefKind::Const { .. }))
870+
.count()
871+
== 0;
871872

872873
let is_exposed_trait = db.symbol_canonical_name(*trait_id).is_some();
873874
// We might want to explicitly omit certain marker traits here that are already handled by the bindings for structs/enums (Copy, Clone, Default, etc.).
@@ -896,12 +897,33 @@ fn generate_trait<'tcx>(
896897
let rs_type = canonical_name.format_for_rs().to_string();
897898
let attributes = vec![quote! {CRUBIT_INTERNAL_RUST_TYPE(#rs_type)}];
898899

900+
let tcx = db.tcx();
901+
let generics = tcx.generics_of(trait_id);
902+
let own_params: Vec<_> = generics
903+
.own_params
904+
.iter()
905+
.filter(|param| matches!(param.kind, GenericParamDefKind::Type { .. }))
906+
.collect();
907+
let trait_params = if generics.has_self { &own_params[1..] } else { &own_params[..] };
908+
909+
let (template_prefix, trait_name_with_args) = if trait_params.is_empty() {
910+
(quote! {}, quote! { #trait_name })
911+
} else {
912+
let template_params = trait_params.iter().enumerate().map(|(i, _)| {
913+
let param_name = format_ident!("T{}", i);
914+
quote! { typename #param_name }
915+
});
916+
let template_args = trait_params.iter().enumerate().map(|(i, _)| format_ident!("T{}", i));
917+
(quote! { template <#(#template_params),*> }, quote! { #trait_name<#(#template_args),*> })
918+
};
919+
899920
let main_api = CcSnippet::with_include(
900921
quote! {
901922
__NEWLINE__ #doc_comment
923+
#template_prefix
902924
struct #(#attributes)* #trait_name {
903925
template <typename T>
904-
using impl = rs_std::impl<T, #trait_name>;
926+
using impl = rs_std::impl<T, #trait_name_with_args>;
905927
};
906928
__NEWLINE__
907929
},
@@ -1659,62 +1681,95 @@ fn generate_trait_impls<'a, 'tcx>(
16591681
.map(move |impl_def_id| (adt_cc_name.clone(), trait_def_id, impl_def_id))
16601682
})
16611683
})
1662-
.map(move |(adt_cc_name, trait_def_id, impl_def_id)| {
1663-
let canonical_name = db.symbol_canonical_name(trait_def_id).expect(
1664-
"symbol_canonical_name was unexpectedly called on a trait without a canonical name",
1665-
);
1666-
let trait_name = canonical_name.format_for_cc(db).map_err(|err| (impl_def_id, err))?;
1667-
let mut prereqs = CcPrerequisites::default();
1668-
if trait_def_id.krate == db.source_crate_num() {
1669-
prereqs.defs.insert(trait_def_id);
1670-
} else {
1671-
let other_crate_name = tcx.crate_name(trait_def_id.krate);
1672-
let crate_name_to_include_paths = db.crate_name_to_include_paths();
1673-
let includes = crate_name_to_include_paths
1674-
.get(other_crate_name.as_str())
1675-
.ok_or_else(|| {
1676-
let trait_name = tcx.def_path_str(trait_def_id);
1677-
(
1678-
impl_def_id,
1679-
anyhow!(
1680-
"Trait `{trait_name}` comes from the `{other_crate_name}` crate, \
1681-
but no `--crate-header` was specified for this crate"
1682-
),
1683-
)
1684-
})?;
1685-
prereqs.includes.extend(includes.iter().cloned());
1686-
}
1684+
.map(
1685+
move |(adt_cc_name, trait_def_id, impl_def_id)| -> Result<ApiSnippets, (DefId, Error)> {
1686+
let trait_header = tcx.impl_trait_header(impl_def_id);
1687+
#[rustversion::before(2025-10-17)]
1688+
let trait_header = trait_header.expect("Trait impl should have a trait header");
1689+
let trait_ref = trait_header.trait_ref.instantiate_identity();
1690+
1691+
let canonical_name = db.symbol_canonical_name(trait_def_id).expect(
1692+
"symbol_canonical_name was unexpectedly called on a trait without a canonical name",
1693+
);
1694+
let trait_name =
1695+
canonical_name.format_for_cc(db).map_err(|err| (impl_def_id, err))?;
16871696

1688-
let mut member_function_names = HashSet::new();
1689-
let assoc_items: ApiSnippets = tcx
1690-
.associated_items(impl_def_id)
1691-
.in_definition_order()
1692-
.flat_map(|assoc_item| {
1693-
generate_associated_item(db, assoc_item, &mut member_function_names)
1694-
})
1695-
.collect();
1697+
let mut prereqs = CcPrerequisites::default();
1698+
let trait_args: Vec<_> = trait_ref
1699+
.args
1700+
.iter()
1701+
// Skip self type.
1702+
.skip(1)
1703+
.filter_map(|arg| arg.as_type())
1704+
.map(|arg| {
1705+
if arg.flags().contains(has_type_or_const_vars()) {
1706+
return Err((impl_def_id, anyhow!("Implementation of traits must specify all types to receive bindings.")));
1707+
}
1708+
db.format_ty_for_cc(arg, TypeLocation::Other)
1709+
.map(|snippet| snippet.into_tokens(&mut prereqs))
1710+
.map_err(|err| (impl_def_id, err))
1711+
})
1712+
.collect::<Result<Vec<_>, _>>()?;
16961713

1697-
let main_api = assoc_items.main_api.into_tokens(&mut prereqs);
1698-
prereqs.includes.insert(db.support_header("rs_std/traits.h"));
1714+
let type_args = if trait_args.is_empty() {
1715+
quote! {}
1716+
} else {
1717+
quote! { <#(#trait_args),*> }
1718+
};
16991719

1700-
Ok(ApiSnippets {
1701-
main_api: CcSnippet {
1702-
tokens: quote! {
1703-
__NEWLINE__
1704-
template<>
1705-
struct rs_std::impl<#adt_cc_name, #trait_name> {
1706-
static constexpr bool kIsImplemented = true;
1720+
let trait_name_with_args = quote! { #trait_name #type_args };
17071721

1708-
#main_api
1709-
};
1710-
__NEWLINE__
1722+
if trait_def_id.krate == db.source_crate_num() {
1723+
prereqs.defs.insert(trait_def_id);
1724+
} else {
1725+
let other_crate_name = tcx.crate_name(trait_def_id.krate);
1726+
let crate_name_to_include_paths = db.crate_name_to_include_paths();
1727+
let includes = crate_name_to_include_paths
1728+
.get(other_crate_name.as_str())
1729+
.ok_or_else(|| {
1730+
let trait_name = tcx.def_path_str(trait_def_id);
1731+
(
1732+
impl_def_id,
1733+
anyhow!(
1734+
"Trait `{trait_name}` comes from the `{other_crate_name}` crate, \
1735+
but no `--crate-header` was specified for this crate"
1736+
),
1737+
)
1738+
})?;
1739+
prereqs.includes.extend(includes.iter().cloned());
1740+
}
1741+
1742+
let mut member_function_names = HashSet::new();
1743+
let assoc_items: ApiSnippets = tcx
1744+
.associated_items(impl_def_id)
1745+
.in_definition_order()
1746+
.flat_map(|assoc_item| {
1747+
generate_associated_item(db, assoc_item, &mut member_function_names)
1748+
})
1749+
.collect();
1750+
1751+
let main_api = assoc_items.main_api.into_tokens(&mut prereqs);
1752+
prereqs.includes.insert(db.support_header("rs_std/traits.h"));
1753+
1754+
Ok(ApiSnippets {
1755+
main_api: CcSnippet {
1756+
tokens: quote! {
1757+
__NEWLINE__
1758+
template<>
1759+
struct rs_std::impl<#adt_cc_name, #trait_name_with_args> {
1760+
static constexpr bool kIsImplemented = true;
1761+
1762+
#main_api
1763+
};
1764+
__NEWLINE__
1765+
},
1766+
prereqs,
17111767
},
1712-
prereqs,
1713-
},
1714-
cc_details: assoc_items.cc_details,
1715-
rs_details: assoc_items.rs_details,
1716-
})
1717-
})
1768+
cc_details: assoc_items.cc_details,
1769+
rs_details: assoc_items.rs_details,
1770+
})
1771+
},
1772+
)
17181773
.map(|results_snippets| {
17191774
results_snippets.unwrap_or_else(|(def_id, err)| {
17201775
generate_unsupported_def(db, def_id, err).into_main_api()

cc_bindings_from_rs/test/traits/BUILD

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,37 @@ crubit_cc_test(
5050
"//testing/base/public:gunit_main",
5151
],
5252
)
53+
54+
rust_library(
55+
name = "generic_traits",
56+
srcs = ["generic_traits.rs"],
57+
aspect_hints = [
58+
"//features:experimental",
59+
],
60+
proc_macro_deps = [
61+
"//support:crubit_annotate",
62+
],
63+
)
64+
65+
cc_bindings_from_rust(
66+
name = "generic_traits_cc_api",
67+
testonly = 1,
68+
crate = ":generic_traits",
69+
)
70+
71+
golden_test(
72+
name = "generic_traits_golden_test",
73+
basename = "generic_traits",
74+
golden_h = "generic_traits_cc_api.h",
75+
golden_rs = "generic_traits_cc_api_impl.rs",
76+
rust_library = "generic_traits",
77+
)
78+
79+
crubit_cc_test(
80+
name = "generic_traits_test",
81+
srcs = ["generic_traits_test.cc"],
82+
deps = [
83+
":generic_traits_cc_api",
84+
"//testing/base/public:gunit_main",
85+
],
86+
)

0 commit comments

Comments
 (0)