@@ -14,6 +14,7 @@ package org.assertj.generator.gradle.tasks
1414
1515import com.google.common.collect.Sets
1616import com.google.common.reflect.TypeToken
17+ import org.assertj.assertions.generator.AssertionsEntryPointType
1718import org.assertj.assertions.generator.BaseAssertionGenerator
1819import org.assertj.assertions.generator.description.ClassDescription
1920import org.assertj.assertions.generator.description.converter.ClassToClassDescriptionConverter
@@ -25,6 +26,8 @@ import org.gradle.api.file.*
2526import org.gradle.api.logging.Logging
2627import org.gradle.api.model.ObjectFactory
2728import org.gradle.api.provider.ListProperty
29+ import org.gradle.api.provider.Property
30+ import org.gradle.api.provider.SetProperty
2831import org.gradle.api.tasks.*
2932import org.gradle.api.tasks.incremental.IncrementalTaskInputs
3033
@@ -37,36 +40,71 @@ import static com.google.common.collect.Sets.newLinkedHashSet
3740/**
3841 * Executes AssertJ generation against provided sources using the configured templates.
3942 */
43+ @CacheableTask
4044class AssertJGenerationTask extends SourceTask {
4145
4246 private static final logger = Logging . getLogger(AssertJGenerationTask )
4347
48+ @InputFiles
4449 @Classpath
4550 final ConfigurableFileCollection generationClasspath
4651
52+ @OutputDirectory
53+ final DirectoryProperty outputDir
54+
55+ // TODO `internal` when Konverted
56+ @Input
57+ final Property<Boolean > skip
58+
59+ // TODO `internal` when Konverted
60+ @Input
61+ final Property<Boolean > hierarchical
62+
63+ // TODO `internal` when Konverted
64+ @Input
65+ final SetProperty<AssertionsEntryPointType > entryPoints
66+
67+ // TODO `internal` when Konverted
68+ @Input
69+ @Optional
70+ final Property<String > entryPointsClassPackage
71+
72+ // TODO `internal` when Konverted
4773 @InputFiles
4874 @Classpath
4975 final FileCollection templateFiles
5076
77+ // TODO `internal` when Konverted
5178 @Input
5279 final ListProperty<SerializedTemplate > generatorTemplates
5380
54- @OutputDirectory
55- final DirectoryProperty outputDir
56-
57- private final AssertJGeneratorExtension assertJOptions
58-
5981 @Inject
6082 AssertJGenerationTask (ObjectFactory objects , SourceSet sourceSet ) {
6183 description = " Generates AssertJ assertions for the ${ sourceSet.name} sources."
6284
63- assertJOptions = sourceSet. extensions. getByType(AssertJGeneratorExtension )
85+ def assertJOptions = sourceSet. extensions. getByType(AssertJGeneratorExtension )
6486 source(sourceSet. allJava)
6587 dependsOn sourceSet. compileJavaTaskName
6688
6789 this . generationClasspath = objects. fileCollection()
6890 .from(sourceSet. runtimeClasspath)
6991
92+ this . skip = objects. property(Boolean ). tap {
93+ set(project. provider { assertJOptions. skip })
94+ }
95+
96+ this . hierarchical = objects. property(Boolean ). tap {
97+ set(project. provider { assertJOptions. hierarchical })
98+ }
99+
100+ this . entryPoints = objects. setProperty(AssertionsEntryPointType ). tap {
101+ set(project. provider { assertJOptions. entryPoints. entryPoints })
102+ }
103+
104+ this . entryPointsClassPackage = objects. property(String ). tap {
105+ set(project. provider { assertJOptions. entryPoints. classPackage })
106+ }
107+
70108 this . outputDir = assertJOptions. outputDir
71109 // TODO Make `templates.templateFiles` `internal` once `AssertJGenerationTask` is Kotlin
72110 this . templateFiles = assertJOptions. templates. templateFiles
@@ -76,7 +114,7 @@ class AssertJGenerationTask extends SourceTask {
76114
77115 @TaskAction
78116 def execute (IncrementalTaskInputs inputs ) {
79- if (assertJOptions . skip) {
117+ if (skip. get() ) {
80118 return
81119 }
82120
@@ -125,63 +163,54 @@ class AssertJGenerationTask extends SourceTask {
125163 [(classesByTypeName[it]): file]
126164 }
127165 }
128- .collectEntries()
166+ .collectEntries() as Map< TypeToken<?> , File >
129167
130168 inputClassesToFile. values(). removeAll {
131169 ! classesToGenerate. contains(it)
132170 }
133171
172+ runGeneration(classes, inputClassNames, inputClassesToFile)
173+ }
174+
175+ private def runGeneration (
176+ Set<TypeToken<?> > allClasses ,
177+ Map<File , Set<String > > inputClassNames ,
178+ Map<TypeToken<?> , File > inputClassesToFile
179+ ) {
134180 BaseAssertionGenerator generator = new BaseAssertionGenerator ()
135181 ClassToClassDescriptionConverter converter = new ClassToClassDescriptionConverter ()
136182
137183 def absOutputDir = project. rootDir. toPath(). resolve(this . outputDir. getAsFile(). get(). toPath()). toFile()
138184
139- Set<TypeToken<?> > filteredClasses = removeAssertClasses(classes )
185+ Set<TypeToken<?> > filteredClasses = removeAssertClasses(allClasses )
140186 def report = new AssertionsGeneratorReport (
141187 absOutputDir,
142188 inputClassNames. values(). flatten(). collect { it. toString() },
143- classes - filteredClasses,
189+ allClasses - filteredClasses,
144190 )
145191
146- def templates = assertJOptions . templates . generatorTemplates. get(). collect { it. maybeLoadTemplate() }. findAll()
192+ def templates = generatorTemplates. get(). collect { it. maybeLoadTemplate() }. findAll()
147193 for (template in templates) {
148194 generator. register(template)
149195 }
150196
151197 try {
152198 generator. directoryWhereAssertionFilesAreGenerated = absOutputDir
153199
154- Set<ClassDescription > classDescriptions = new LinkedHashSet<> ()
155-
156- if (assertJOptions. hierarchical) {
157- for (clazz in filteredClasses) {
158- ClassDescription classDescription = converter. convertToClassDescription(clazz)
159- File [] generatedCustomAssertionFiles = generator. generateHierarchicalCustomAssertionFor(
160- classDescription,
161- filteredClasses,
162- )
163- report. addGeneratedAssertionFiles(generatedCustomAssertionFiles)
164- classDescriptions. add(classDescription)
165- }
200+ def classDescriptions
201+ if (hierarchical. get()) {
202+ classDescriptions = generateHierarchical(converter, generator, report, filteredClasses)
166203 } else {
167- for (clazz in filteredClasses) {
168- def classDescription = converter. convertToClassDescription(clazz)
169- classDescriptions. add(classDescription)
170-
171- if (inputClassesToFile. containsKey(clazz)) {
172- File generatedCustomAssertionFile = generator. generateCustomAssertionFor(classDescription)
173- report. addGeneratedAssertionFiles(generatedCustomAssertionFile)
174- }
175- }
204+ classDescriptions = generateFlat(converter, generator, report, filteredClasses, inputClassesToFile)
176205 }
177206
178207 if (! inputClassesToFile. isEmpty()) {
179208 // only generate the entry points if there are classes that have changed (or exist..)
180- for (assertionsEntryPointType in assertJOptions . entryPoints. entryPoints ) {
209+ for (assertionsEntryPointType in entryPoints. get() ) {
181210 File assertionsEntryPointFile = generator. generateAssertionsEntryPointClassFor(
182- classDescriptions,
211+ classDescriptions. toSet() ,
183212 assertionsEntryPointType,
184- assertJOptions . entryPoints . classPackage ,
213+ entryPointsClassPackage . getOrNull() ,
185214 )
186215 report. reportEntryPointGeneration(assertionsEntryPointType, assertionsEntryPointFile)
187216 }
@@ -193,6 +222,42 @@ class AssertJGenerationTask extends SourceTask {
193222 logger. info(report. getReportContent())
194223 }
195224
225+ private static Collection<ClassDescription > generateHierarchical (
226+ ClassToClassDescriptionConverter converter ,
227+ BaseAssertionGenerator generator ,
228+ AssertionsGeneratorReport report ,
229+ Set<TypeToken<?> > classes
230+ ) {
231+ classes. collect { clazz ->
232+ def classDescription = converter. convertToClassDescription(clazz)
233+ def generatedCustomAssertionFiles = generator. generateHierarchicalCustomAssertionFor(
234+ classDescription,
235+ classes,
236+ )
237+ report. addGeneratedAssertionFiles(generatedCustomAssertionFiles)
238+ classDescription
239+ }
240+ }
241+
242+ private static Collection<ClassDescription > generateFlat (
243+ ClassToClassDescriptionConverter converter ,
244+ BaseAssertionGenerator generator ,
245+ AssertionsGeneratorReport report ,
246+ Set<TypeToken<?> > classes ,
247+ Map<TypeToken<?> , File > inputClassesToFile
248+ ) {
249+ classes. collect { clazz ->
250+ def classDescription = converter. convertToClassDescription(clazz)
251+
252+ if (inputClassesToFile. containsKey(clazz)) {
253+ def generatedCustomAssertionFile = generator. generateCustomAssertionFor(classDescription)
254+ report. addGeneratedAssertionFiles(generatedCustomAssertionFile)
255+ }
256+
257+ classDescription
258+ }
259+ }
260+
196261 /**
197262 * Returns the source for this task, after the include and exclude patterns have been applied. Ignores source files which do not exist.
198263 *
0 commit comments