Skip to content

Commit a979a52

Browse files
Add a method to resolve recursion in optional dependencies
1 parent 2ed4df2 commit a979a52

File tree

4 files changed

+295
-82
lines changed

4 files changed

+295
-82
lines changed

src/has_recursion.rs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
use indexmap::IndexMap;
2+
use pep508_rs::Requirement;
3+
use std::ops::Deref;
4+
use thiserror::Error;
5+
6+
/// A trait that resolves recursions for groups of requirements that can be mapped to `IndexMap<String, Vec<T>>`
7+
/// where T is a type that can be mapped to either a Requirement or a reference to other groups of requirements.
8+
pub trait HasRecursion<T>: Deref<Target = IndexMap<String, Vec<T>>>
9+
where
10+
T: RecursionItem,
11+
{
12+
/// Resolves the groups of requirements into flat lists of requirements.
13+
fn resolve_all(
14+
&self,
15+
name: Option<&str>,
16+
) -> Result<IndexMap<String, Vec<Requirement>>, RecursionResolutionError> {
17+
// Helper function to resolve a single group
18+
fn resolve_single<'a, T: RecursionItem>(
19+
groups: &'a IndexMap<String, Vec<T>>,
20+
group: &'a str,
21+
resolved: &mut IndexMap<String, Vec<Requirement>>,
22+
parents: &mut Vec<&'a str>,
23+
name: Option<&'a str>,
24+
) -> Result<(), RecursionResolutionError> {
25+
let Some(items) = groups.get(group) else {
26+
// If the group included in another group does not exist, return an error
27+
let parent = parents.iter().last().expect("should have a parent");
28+
return Err(RecursionResolutionError::GroupNotFound(
29+
T::group_name(),
30+
group.to_string(),
31+
parent.to_string(),
32+
));
33+
};
34+
// If there is a cycle in dependency groups, return an error
35+
if parents.contains(&group) {
36+
return Err(RecursionResolutionError::DependencyGroupCycle(
37+
T::table_name(),
38+
Cycle(parents.iter().map(|s| s.to_string()).collect()),
39+
));
40+
}
41+
// If the group has already been resolved, exit early
42+
if resolved.get(group).is_some() {
43+
return Ok(());
44+
}
45+
// Otherwise, perform recursion, as required, on the dependency group's specifiers
46+
parents.push(group);
47+
let mut requirements = Vec::with_capacity(items.len());
48+
for spec in items.iter() {
49+
match spec.parse(name) {
50+
// It's a requirement. Just add it to the Vec of resolved requirements
51+
Item::Requirement(requirement) => requirements.push(requirement.clone()),
52+
// It's a reference to other groups. Recurse into them
53+
Item::Groups(inner_groups) => {
54+
for group in inner_groups {
55+
resolve_single(groups, group, resolved, parents, name)?;
56+
requirements.extend(resolved.get(group).into_iter().flatten().cloned());
57+
}
58+
}
59+
}
60+
}
61+
// Add the resolved group to IndexMap
62+
resolved.insert(group.to_string(), requirements.clone());
63+
parents.pop();
64+
Ok(())
65+
}
66+
67+
let mut resolved = IndexMap::new();
68+
for group in self.keys() {
69+
resolve_single(self, group, &mut resolved, &mut Vec::new(), name)?;
70+
}
71+
Ok(resolved)
72+
}
73+
}
74+
/// A trait that defines how to parse a recursion item.
75+
pub trait RecursionItem {
76+
/// Parse the item into a requirement or a reference to other groups.
77+
fn parse<'a>(&'a self, name: Option<&str>) -> Item<'a>;
78+
/// The name of the group in the TOML file.
79+
fn group_name() -> String;
80+
/// The name of the table in the TOML file.
81+
fn table_name() -> String;
82+
}
83+
84+
pub enum Item<'a> {
85+
Requirement(Requirement),
86+
Groups(Vec<&'a str>),
87+
}
88+
89+
#[derive(Debug, Error)]
90+
pub enum RecursionResolutionError {
91+
#[error("Failed to find {0} `{1}` included by `{2}`")]
92+
GroupNotFound(String, String, String),
93+
#[error("Detected a cycle in `{0}`: {1}")]
94+
DependencyGroupCycle(String, Cycle),
95+
}
96+
97+
/// A cycle in the recursion.
98+
#[derive(Debug)]
99+
pub struct Cycle(Vec<String>);
100+
101+
/// Display a cycle, e.g., `a -> b -> c -> a`.
102+
impl std::fmt::Display for Cycle {
103+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104+
let [first, rest @ ..] = self.0.as_slice() else {
105+
return Ok(());
106+
};
107+
write!(f, "`{first}`")?;
108+
for group in rest {
109+
write!(f, " -> `{group}`")?;
110+
}
111+
write!(f, " -> `{first}`")?;
112+
Ok(())
113+
}
114+
}

src/lib.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ mod pep639_glob;
44
#[cfg(feature = "pep639-glob")]
55
pub use pep639_glob::{parse_pep639_glob, Pep639GlobError};
66

7+
pub mod has_recursion;
8+
pub mod optional_dependencies_resolve;
79
pub mod pep735_resolve;
810

911
use indexmap::IndexMap;
@@ -83,7 +85,7 @@ pub struct Project {
8385
/// Project dependencies
8486
pub dependencies: Option<Vec<Requirement>>,
8587
/// Optional dependencies
86-
pub optional_dependencies: Option<IndexMap<String, Vec<Requirement>>>,
88+
pub optional_dependencies: Option<OptionalDependencies>,
8789
/// Specifies which fields listed by PEP 621 were intentionally unspecified
8890
/// so another tool can/will provide such metadata dynamically.
8991
pub dynamic: Option<Vec<String>>,
@@ -115,6 +117,23 @@ impl Project {
115117
}
116118
}
117119

120+
/// The `[project.optional-dependencies]` section of pyproject.toml
121+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
122+
#[serde(transparent)]
123+
pub struct OptionalDependencies {
124+
pub inner: IndexMap<String, Vec<Requirement>>,
125+
#[serde(skip)]
126+
pub self_reference_name: Option<String>,
127+
}
128+
129+
impl Deref for OptionalDependencies {
130+
type Target = IndexMap<String, Vec<Requirement>>;
131+
132+
fn deref(&self) -> &Self::Target {
133+
&self.inner
134+
}
135+
}
136+
118137
/// The full description of the project (i.e. the README).
119138
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
120139
#[serde(rename_all = "kebab-case")]
@@ -228,7 +247,17 @@ pub enum DependencyGroupSpecifier {
228247
impl PyProjectToml {
229248
/// Parse `pyproject.toml` content
230249
pub fn new(content: &str) -> Result<Self, toml::de::Error> {
231-
toml::de::from_str(content)
250+
let mut pyproject: PyProjectToml = toml::de::from_str(content)?;
251+
252+
// Set the project name as optional-dependencies self_reference_name
253+
if let Some(project) = pyproject.project.as_mut() {
254+
let name = project.name.clone();
255+
if let Some(od) = project.optional_dependencies.as_mut() {
256+
od.self_reference_name = Some(name);
257+
}
258+
}
259+
260+
Ok(pyproject)
232261
}
233262
}
234263

@@ -337,6 +366,14 @@ tomatoes = "spam:main_tomatoes""#;
337366
project.gui_scripts.as_ref().unwrap()["spam-gui"],
338367
"spam:main_gui"
339368
);
369+
assert_eq!(
370+
project
371+
.optional_dependencies
372+
.as_ref()
373+
.unwrap()
374+
.self_reference_name,
375+
Some("spam".to_string())
376+
);
340377
}
341378

342379
#[test]
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use crate::{
2+
has_recursion::{HasRecursion, Item, RecursionItem, RecursionResolutionError},
3+
OptionalDependencies,
4+
};
5+
use indexmap::IndexMap;
6+
use pep508_rs::Requirement;
7+
8+
impl HasRecursion<Requirement> for OptionalDependencies {}
9+
10+
impl RecursionItem for Requirement {
11+
fn parse(&self, name: Option<&str>) -> Item {
12+
if name.map(|n| n == self.name.to_string()).unwrap_or(false) {
13+
Item::Groups(self.extras.iter().map(|extra| extra.as_ref()).collect())
14+
} else {
15+
Item::Requirement(self.clone())
16+
}
17+
}
18+
fn table_name() -> String {
19+
"project.optional-dependencies".to_string()
20+
}
21+
fn group_name() -> String {
22+
"optional dependency group".to_string()
23+
}
24+
}
25+
26+
impl OptionalDependencies {
27+
/// Resolve the optional dependencies into a flat list of external requirements.
28+
///
29+
/// This function will recursively resolve all references to other groups of optional dependencies,
30+
/// It will return an error if there is a cycle in the optional
31+
/// dependencies or if a group references another group that does not exist.
32+
pub fn resolve(&self) -> Result<IndexMap<String, Vec<Requirement>>, RecursionResolutionError> {
33+
self.resolve_all(self.self_reference_name.as_deref())
34+
}
35+
}
36+
37+
#[cfg(test)]
38+
mod tests {
39+
use pep508_rs::Requirement;
40+
use std::str::FromStr;
41+
42+
use crate::PyProjectToml;
43+
44+
#[test]
45+
fn test_parse_pyproject_toml_dependency_groups_resolve() {
46+
let source = r#"[project]
47+
name = "spam"
48+
49+
[project.optional-dependencies]
50+
alpha = ["beta", "gamma", "delta"]
51+
epsilon = ["eta<2.0", "theta==2024.09.01"]
52+
iota = ["spam[alpha]"]
53+
"#;
54+
let project_toml = PyProjectToml::new(source).unwrap();
55+
let optional_dependencies = project_toml
56+
.project
57+
.as_ref()
58+
.unwrap()
59+
.optional_dependencies
60+
.as_ref()
61+
.unwrap();
62+
63+
assert_eq!(
64+
optional_dependencies.resolve().unwrap()["iota"],
65+
vec![
66+
Requirement::from_str("beta").unwrap(),
67+
Requirement::from_str("gamma").unwrap(),
68+
Requirement::from_str("delta").unwrap()
69+
]
70+
);
71+
}
72+
73+
#[test]
74+
fn test_parse_pyproject_toml_dependency_groups_cycle() {
75+
let source = r#"[project]
76+
name = "spam"
77+
78+
[project.optional-dependencies]
79+
alpha = ["spam[iota]"]
80+
iota = ["spam[alpha]"]
81+
"#;
82+
let project_toml = PyProjectToml::new(source).unwrap();
83+
let optional_dependencies = project_toml
84+
.project
85+
.as_ref()
86+
.unwrap()
87+
.optional_dependencies
88+
.as_ref()
89+
.unwrap();
90+
assert_eq!(
91+
optional_dependencies.resolve().unwrap_err().to_string(),
92+
String::from(
93+
"Detected a cycle in `project.optional-dependencies`: `alpha` -> `iota` -> `alpha`"
94+
)
95+
)
96+
}
97+
98+
#[test]
99+
fn test_parse_pyproject_toml_dependency_groups_missing_include() {
100+
let source = r#"[project]
101+
name = "spam"
102+
103+
[project.optional-dependencies]
104+
iota = ["spam[alpha]"]
105+
"#;
106+
let project_toml = PyProjectToml::new(source).unwrap();
107+
let optional_dependencies = project_toml
108+
.project
109+
.as_ref()
110+
.unwrap()
111+
.optional_dependencies
112+
.as_ref()
113+
.unwrap();
114+
assert_eq!(
115+
optional_dependencies.resolve().unwrap_err().to_string(),
116+
String::from("Failed to find optional dependency group `alpha` included by `iota`")
117+
)
118+
}
119+
}

0 commit comments

Comments
 (0)