Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 52 additions & 126 deletions rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,19 @@
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.*;
import org.openrewrite.marker.Markers;
import org.openrewrite.python.internal.PyProjectHelper;
import org.openrewrite.python.internal.PythonDependencyExecutionContextView;
import org.openrewrite.python.marker.PythonResolutionResult;
import org.openrewrite.toml.TomlIsoVisitor;
import org.openrewrite.toml.tree.Space;
import org.openrewrite.python.trait.PythonDependencyFile;
import org.openrewrite.toml.tree.Toml;
import org.openrewrite.toml.tree.TomlRightPadded;
import org.openrewrite.toml.tree.TomlType;

import java.util.*;

import static org.openrewrite.Tree.randomId;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
* Add a dependency to the {@code [project].dependencies} array in pyproject.toml.
* Add a dependency to a Python project. Supports {@code pyproject.toml}
* (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}.
* When uv is available, the uv.lock file is regenerated to reflect the change.
*/
@EqualsAndHashCode(callSuper = false)
Expand All @@ -54,7 +51,9 @@ public class AddDependency extends ScanningRecipe<AddDependency.Accumulator> {
String version;

@Option(displayName = "Scope",
description = "The dependency scope to add to. Defaults to `project.dependencies`.",
description = "The dependency scope to add to. For pyproject.toml this targets a specific TOML section. " +
"For requirements files, `null` matches all files, empty string matches only `requirements.txt`, " +
"and a value like `dev` matches `requirements-dev.txt`. Defaults to `project.dependencies`.",
valid = {"project.dependencies", "project.optional-dependencies", "dependency-groups",
"tool.uv.constraint-dependencies", "tool.uv.override-dependencies"},
example = "project.dependencies",
Expand Down Expand Up @@ -90,7 +89,8 @@ public String getInstanceNameSuffix() {

@Override
public String getDescription() {
return "Add a dependency to the `[project].dependencies` array in `pyproject.toml`. " +
return "Add a dependency to a Python project. Supports `pyproject.toml` " +
"(with scope/group targeting), `requirements.txt`, and `Pipfile`. " +
"When `uv` is available, the `uv.lock` file is regenerated.";
}

Expand All @@ -105,141 +105,67 @@ public Accumulator getInitialValue(ExecutionContext ctx) {

@Override
public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
return new TomlIsoVisitor<ExecutionContext>() {
return new TreeVisitor<Tree, ExecutionContext>() {
@Override
public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) {
String sourcePath = document.getSourcePath().toString();

if (sourcePath.endsWith("uv.lock")) {
PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put(
PyProjectHelper.correspondingPyprojectPath(sourcePath),
document.printAll());
return document;
public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) {
if (!(tree instanceof SourceFile)) {
return tree;
}

if (!sourcePath.endsWith("pyproject.toml")) {
return document;
stopAfterPreVisit();
SourceFile sourceFile = (SourceFile) tree;
if (tree instanceof Toml.Document && sourceFile.getSourcePath().toString().endsWith("uv.lock")) {
PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put(
PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath().toString()),
((Toml.Document) tree).printAll());
return tree;
}
Optional<PythonResolutionResult> resolution = document.getMarkers()
.findFirst(PythonResolutionResult.class);
if (!resolution.isPresent()) {
return document;
PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null);
if (trait == null) {
return tree;
}

PythonResolutionResult marker = resolution.get();

// Check if the dependency already exists in the target scope
if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) != null) {
return document;
if (PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) != null) {
return tree;
}

acc.projectsToUpdate.add(sourcePath);
return document;
acc.projectsToUpdate.add(sourceFile.getSourcePath().toString());
return tree;
}
};
}

@Override
public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
return new TomlIsoVisitor<ExecutionContext>() {
return new TreeVisitor<Tree, ExecutionContext>() {
@Override
public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) {
String sourcePath = document.getSourcePath().toString();

if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) {
return addDependencyToPyproject(document, ctx, acc);
public @Nullable Tree preVisit(Tree tree, ExecutionContext ctx) {
if (!(tree instanceof SourceFile)) {
return tree;
}
stopAfterPreVisit();
SourceFile sourceFile = (SourceFile) tree;
String sourcePath = sourceFile.getSourcePath().toString();

if (acc.projectsToUpdate.contains(sourcePath)) {
PythonDependencyFile trait = new PythonDependencyFile.Matcher().get(getCursor()).orElse(null);
if (trait != null) {
String ver = version != null ? version : "";
Map<String, String> additions = Collections.singletonMap(packageName, ver);
PythonDependencyFile updated = trait.withAddedDependencies(additions, scope, groupName);
if (updated.getTree() != tree) {
return updated.afterModification(ctx);
}
}
}

if (sourcePath.endsWith("uv.lock")) {
Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx);
if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) {
Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx);
if (updatedLock != null) {
return updatedLock;
}
}

return document;
return tree;
}
};
}

private Toml.Document addDependencyToPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) {
String pep508 = version != null ? packageName + PyProjectHelper.normalizeVersionConstraint(version) : packageName;

Toml.Document updated = (Toml.Document) new TomlIsoVisitor<ExecutionContext>() {
@Override
public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) {
Toml.Array a = super.visitArray(array, ctx);

if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) {
return a;
}

Toml.Literal newLiteral = new Toml.Literal(
randomId(),
Space.EMPTY,
Markers.EMPTY,
TomlType.Primitive.String,
"\"" + pep508 + "\"",
pep508
);

List<TomlRightPadded<Toml>> existingPadded = a.getPadding().getValues();
List<TomlRightPadded<Toml>> newPadded = new ArrayList<>();

// An empty TOML array [] is represented as a single Toml.Empty element
boolean isEmpty = existingPadded.size() == 1 &&
existingPadded.get(0).getElement() instanceof Toml.Empty;
if (existingPadded.isEmpty() || isEmpty) {
newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY));
} else {
// Check if the last element is Toml.Empty (trailing comma marker)
TomlRightPadded<Toml> lastPadded = existingPadded.get(existingPadded.size() - 1);
boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty;

if (hasTrailingComma) {
// Insert before the Empty element. The Empty's position
// stores the whitespace before ']'.
// Find the last real element to copy its prefix formatting
int lastRealIdx = existingPadded.size() - 2;
Toml lastRealElement = existingPadded.get(lastRealIdx).getElement();
Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix());

// Copy all existing elements up to (not including) the Empty
for (int i = 0; i <= lastRealIdx; i++) {
newPadded.add(existingPadded.get(i));
}
// Add new literal with empty after (comma added by printer)
newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY));
// Keep the Empty element for trailing comma + closing bracket whitespace
newPadded.add(lastPadded);
} else {
// No trailing comma — the last real element's after has the space before ']'
Toml lastElement = lastPadded.getElement();
// For multi-line arrays, use same prefix; for inline, use single space
Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n")
? lastElement.getPrefix()
: Space.SINGLE_SPACE;
Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix);

// Copy all existing elements but set last one's after to empty
for (int i = 0; i < existingPadded.size() - 1; i++) {
newPadded.add(existingPadded.get(i));
}
newPadded.add(lastPadded.withAfter(Space.EMPTY));
// New element gets the after from the old last element
newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY));
}
}

return a.getPadding().withValues(newPadded);
}
}.visitNonNull(document, ctx);

if (updated != document) {
updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx);
}

return updated;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,45 @@ public static SourceSpecs setupCfg(@Nullable String before,
return text;
}

public static SourceSpecs pipfile(@Language("toml") @Nullable String before) {
return pipfile(before, s -> {
});
}

public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
Consumer<SourceSpec<Toml.Document>> spec) {
SourceSpec<Toml.Document> toml = new SourceSpec<>(
Toml.Document.class, null, PipfileParser.builder(), before,
SourceSpec.ValidateSource.noop,
ctx -> {
}
);
toml.path("Pipfile");
spec.accept(toml);
return toml;
}

public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
@Language("toml") @Nullable String after) {
return pipfile(before, after, s -> {
});
}

public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
@Language("toml") @Nullable String after,
Consumer<SourceSpec<Toml.Document>> spec) {
SourceSpec<Toml.Document> toml = new SourceSpec<>(
Toml.Document.class, null, PipfileParser.builder(), before,
SourceSpec.ValidateSource.noop,
ctx -> {
}
);
toml.path("Pipfile");
toml.after(s -> after);
spec.accept(toml);
return toml;
}

public static SourceSpecs python(@Language("py") @Nullable String before) {
return python(before, s -> {
});
Expand Down
Loading