Skip to content

Commit 1191fdc

Browse files
Add a method to resolve recursion in optional dependencies
1 parent 2ed4df2 commit 1191fdc

File tree

4 files changed

+288
-87
lines changed

4 files changed

+288
-87
lines changed

src/has_recursion.rs

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use indexmap::IndexMap;
2+
use pep508_rs::Requirement;
3+
use std::ops::Deref;
4+
use thiserror::Error;
5+
6+
/// A trait that resolves recursions for groups of requirements that can be mapped to `IndexMap<String, Vec<T>>`
7+
/// where T is a type that can be mapped to either a Requirement or a reference to other groups of requirements.
8+
pub trait HasRecursion<T>: Deref<Target = IndexMap<String, Vec<T>>>
9+
where
10+
T: RecursionItem,
11+
{
12+
/// Resolve the groups into lists of requirements.
13+
///
14+
/// This function will recursively resolve all groups, including those that
15+
/// reference other groups. It will return an error if there is a cycle in the
16+
/// groups or if a group references another group that does not exist.
17+
fn resolve(&self) -> Result<IndexMap<String, Vec<Requirement>>, RecursionResolutionError> {
18+
self.resolve_all(None)
19+
}
20+
21+
/// Resolves the groups of requirements into flat lists of requirements.
22+
fn resolve_all(
23+
&self,
24+
name: Option<&str>,
25+
) -> Result<IndexMap<String, Vec<Requirement>>, RecursionResolutionError> {
26+
// Helper function to resolve a single group
27+
fn resolve_single<'a, T: RecursionItem>(
28+
groups: &'a IndexMap<String, Vec<T>>,
29+
group: &'a str,
30+
resolved: &mut IndexMap<String, Vec<Requirement>>,
31+
parents: &mut Vec<&'a str>,
32+
name: Option<&'a str>,
33+
) -> Result<(), RecursionResolutionError> {
34+
let Some(items) = groups.get(group) else {
35+
// If the group included in another group does not exist, return an error
36+
let parent = parents.iter().last().expect("should have a parent");
37+
return Err(RecursionResolutionError::GroupNotFound(
38+
T::group_name(),
39+
group.to_string(),
40+
parent.to_string(),
41+
));
42+
};
43+
// If there is a cycle in dependency groups, return an error
44+
if parents.contains(&group) {
45+
return Err(RecursionResolutionError::DependencyGroupCycle(
46+
T::table_name(),
47+
Cycle(parents.iter().map(|s| s.to_string()).collect()),
48+
));
49+
}
50+
// If the group has already been resolved, exit early
51+
if resolved.get(group).is_some() {
52+
return Ok(());
53+
}
54+
// Otherwise, perform recursion, as required, on the dependency group's specifiers
55+
parents.push(group);
56+
let mut requirements = Vec::with_capacity(items.len());
57+
for spec in items.iter() {
58+
match spec.parse(name) {
59+
// It's a requirement. Just add it to the Vec of resolved requirements
60+
Item::Requirement(requirement) => requirements.push(requirement.clone()),
61+
// It's a reference to other groups. Recurse into them
62+
Item::Groups(inner_groups) => {
63+
for group in inner_groups {
64+
resolve_single(groups, group, resolved, parents, name)?;
65+
requirements.extend(resolved.get(group).into_iter().flatten().cloned());
66+
}
67+
}
68+
}
69+
}
70+
// Add the resolved group to IndexMap
71+
resolved.insert(group.to_string(), requirements.clone());
72+
parents.pop();
73+
Ok(())
74+
}
75+
76+
let mut resolved = IndexMap::new();
77+
for group in self.keys() {
78+
resolve_single(self, group, &mut resolved, &mut Vec::new(), name)?;
79+
}
80+
Ok(resolved)
81+
}
82+
}
83+
/// A trait that defines how to parse a recursion item.
84+
pub trait RecursionItem {
85+
/// Parse the item into a requirement or a reference to other groups.
86+
fn parse<'a>(&'a self, name: Option<&str>) -> Item<'a>;
87+
/// The name of the group in the TOML file.
88+
fn group_name() -> String;
89+
/// The name of the table in the TOML file.
90+
fn table_name() -> String;
91+
}
92+
93+
pub enum Item<'a> {
94+
Requirement(Requirement),
95+
Groups(Vec<&'a str>),
96+
}
97+
98+
#[derive(Debug, Error)]
99+
pub enum RecursionResolutionError {
100+
#[error("Failed to find {0} `{1}` included by `{2}`")]
101+
GroupNotFound(String, String, String),
102+
#[error("Detected a cycle in `{0}`: {1}")]
103+
DependencyGroupCycle(String, Cycle),
104+
}
105+
106+
/// A cycle in the recursion.
107+
#[derive(Debug)]
108+
pub struct Cycle(Vec<String>);
109+
110+
/// Display a cycle, e.g., `a -> b -> c -> a`.
111+
impl std::fmt::Display for Cycle {
112+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113+
let [first, rest @ ..] = self.0.as_slice() else {
114+
return Ok(());
115+
};
116+
write!(f, "`{first}`")?;
117+
for group in rest {
118+
write!(f, " -> `{group}`")?;
119+
}
120+
write!(f, " -> `{first}`")?;
121+
Ok(())
122+
}
123+
}

src/lib.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ mod pep639_glob;
44
#[cfg(feature = "pep639-glob")]
55
pub use pep639_glob::{parse_pep639_glob, Pep639GlobError};
66

7+
pub mod has_recursion;
8+
pub mod optional_dependencies_resolve;
79
pub mod pep735_resolve;
810

911
use indexmap::IndexMap;
@@ -83,7 +85,7 @@ pub struct Project {
8385
/// Project dependencies
8486
pub dependencies: Option<Vec<Requirement>>,
8587
/// Optional dependencies
86-
pub optional_dependencies: Option<IndexMap<String, Vec<Requirement>>>,
88+
pub optional_dependencies: Option<OptionalDependencies>,
8789
/// Specifies which fields listed by PEP 621 were intentionally unspecified
8890
/// so another tool can/will provide such metadata dynamically.
8991
pub dynamic: Option<Vec<String>>,
@@ -115,6 +117,23 @@ impl Project {
115117
}
116118
}
117119

120+
/// The `[project.optional-dependencies]` section of pyproject.toml
121+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
122+
#[serde(transparent)]
123+
pub struct OptionalDependencies {
124+
pub inner: IndexMap<String, Vec<Requirement>>,
125+
#[serde(skip)]
126+
pub self_reference_name: Option<String>,
127+
}
128+
129+
impl Deref for OptionalDependencies {
130+
type Target = IndexMap<String, Vec<Requirement>>;
131+
132+
fn deref(&self) -> &Self::Target {
133+
&self.inner
134+
}
135+
}
136+
118137
/// The full description of the project (i.e. the README).
119138
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
120139
#[serde(rename_all = "kebab-case")]
@@ -228,7 +247,17 @@ pub enum DependencyGroupSpecifier {
228247
impl PyProjectToml {
229248
/// Parse `pyproject.toml` content
230249
pub fn new(content: &str) -> Result<Self, toml::de::Error> {
231-
toml::de::from_str(content)
250+
let mut pyproject: PyProjectToml = toml::de::from_str(content)?;
251+
252+
// Set the project name as optional-dependencies self_reference_name
253+
if let Some(project) = pyproject.project.as_mut() {
254+
let name = project.name.clone();
255+
if let Some(od) = project.optional_dependencies.as_mut() {
256+
od.self_reference_name = Some(name);
257+
}
258+
}
259+
260+
Ok(pyproject)
232261
}
233262
}
234263

@@ -337,6 +366,14 @@ tomatoes = "spam:main_tomatoes""#;
337366
project.gui_scripts.as_ref().unwrap()["spam-gui"],
338367
"spam:main_gui"
339368
);
369+
assert_eq!(
370+
project
371+
.optional_dependencies
372+
.as_ref()
373+
.unwrap()
374+
.self_reference_name,
375+
Some("spam".to_string())
376+
);
340377
}
341378

342379
#[test]
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
use crate::{
2+
has_recursion::{HasRecursion, Item, RecursionItem, RecursionResolutionError},
3+
OptionalDependencies,
4+
};
5+
use indexmap::IndexMap;
6+
use pep508_rs::Requirement;
7+
8+
impl HasRecursion<Requirement> for OptionalDependencies {
9+
fn resolve(&self) -> Result<IndexMap<String, Vec<Requirement>>, RecursionResolutionError> {
10+
self.resolve_all(self.self_reference_name.as_deref())
11+
}
12+
}
13+
14+
impl RecursionItem for Requirement {
15+
fn parse(&self, name: Option<&str>) -> Item {
16+
if name.map(|n| n == self.name.to_string()).unwrap_or(false) {
17+
Item::Groups(self.extras.iter().map(|extra| extra.as_ref()).collect())
18+
} else {
19+
Item::Requirement(self.clone())
20+
}
21+
}
22+
fn table_name() -> String {
23+
"project.optional-dependencies".to_string()
24+
}
25+
fn group_name() -> String {
26+
"optional dependency group".to_string()
27+
}
28+
}
29+
#[cfg(test)]
30+
mod tests {
31+
use pep508_rs::Requirement;
32+
use std::str::FromStr;
33+
34+
use crate::PyProjectToml;
35+
36+
#[test]
37+
fn test_parse_pyproject_toml_dependency_groups_resolve() {
38+
let source = r#"[project]
39+
name = "spam"
40+
41+
[project.optional-dependencies]
42+
alpha = ["beta", "gamma", "delta"]
43+
epsilon = ["eta<2.0", "theta==2024.09.01"]
44+
iota = ["spam[alpha]"]
45+
"#;
46+
let project_toml = PyProjectToml::new(source).unwrap();
47+
let optional_dependencies = project_toml
48+
.project
49+
.as_ref()
50+
.unwrap()
51+
.optional_dependencies
52+
.as_ref()
53+
.unwrap();
54+
55+
assert_eq!(
56+
optional_dependencies.resolve().unwrap()["iota"],
57+
vec![
58+
Requirement::from_str("beta").unwrap(),
59+
Requirement::from_str("gamma").unwrap(),
60+
Requirement::from_str("delta").unwrap()
61+
]
62+
);
63+
}
64+
65+
#[test]
66+
fn test_parse_pyproject_toml_dependency_groups_cycle() {
67+
let source = r#"[project]
68+
name = "spam"
69+
70+
[project.optional-dependencies]
71+
alpha = ["spam[iota]"]
72+
iota = ["spam[alpha]"]
73+
"#;
74+
let project_toml = PyProjectToml::new(source).unwrap();
75+
let optional_dependencies = project_toml
76+
.project
77+
.as_ref()
78+
.unwrap()
79+
.optional_dependencies
80+
.as_ref()
81+
.unwrap();
82+
assert_eq!(
83+
optional_dependencies.resolve().unwrap_err().to_string(),
84+
String::from(
85+
"Detected a cycle in `project.optional-dependencies`: `alpha` -> `iota` -> `alpha`"
86+
)
87+
)
88+
}
89+
90+
#[test]
91+
fn test_parse_pyproject_toml_dependency_groups_missing_include() {
92+
let source = r#"[project]
93+
name = "spam"
94+
95+
[project.optional-dependencies]
96+
iota = ["spam[alpha]"]
97+
"#;
98+
let project_toml = PyProjectToml::new(source).unwrap();
99+
let optional_dependencies = project_toml
100+
.project
101+
.as_ref()
102+
.unwrap()
103+
.optional_dependencies
104+
.as_ref()
105+
.unwrap();
106+
assert_eq!(
107+
optional_dependencies.resolve().unwrap_err().to_string(),
108+
String::from("Failed to find optional dependency group `alpha` included by `iota`")
109+
)
110+
}
111+
}

0 commit comments

Comments
 (0)