Skip to content

Commit 77237f0

Browse files
authored
Add class script wrapper (#2033)
* Add class code wrapper for scala 3 scripts * Move wrapping scripts to right before building * Add warning when @main is used in scripts wrapped in class * NIT Move mainClassObject def to CodeWrapper object * Add new types to enforce wrapping scripts before building project
1 parent 9dee77c commit 77237f0

File tree

27 files changed

+651
-159
lines changed

27 files changed

+651
-159
lines changed

modules/build/src/main/scala/scala/build/Build.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import scala.build.errors.*
1818
import scala.build.input.VirtualScript.VirtualScriptNameRegex
1919
import scala.build.input.*
2020
import scala.build.internal.resource.ResourceMapper
21-
import scala.build.internal.{Constants, CustomCodeWrapper, MainClass, Util}
21+
import scala.build.internal.{Constants, MainClass, Util}
2222
import scala.build.options.ScalaVersionUtil.asVersion
2323
import scala.build.options.*
2424
import scala.build.options.validation.ValidationException
@@ -227,7 +227,6 @@ object Build {
227227
CrossSources.forInputs(
228228
inputs,
229229
Sources.defaultPreprocessors(
230-
options.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
231230
options.archiveCache,
232231
options.internal.javaClassNameVersionOpt,
233232
() => options.javaHome().value.javaCommand
@@ -266,8 +265,11 @@ object Build {
266265
overrideOptions: BuildOptions
267266
): Either[BuildException, NonCrossBuilds] = either {
268267

269-
val baseOptions = overrideOptions.orElse(sharedOptions)
270-
val scopedSources = value(crossSources.scopedSources(baseOptions))
268+
val baseOptions = overrideOptions.orElse(sharedOptions)
269+
270+
val wrappedScriptsSources = crossSources.withWrappedScripts(baseOptions)
271+
272+
val scopedSources = value(wrappedScriptsSources.scopedSources(baseOptions))
271273

272274
val mainSources = scopedSources.sources(Scope.Main, baseOptions)
273275
val mainOptions = mainSources.buildOptions

modules/build/src/main/scala/scala/build/CrossSources.scala

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,94 @@ import scala.build.testrunner.DynamicTestRunner.globPattern
2828
import scala.util.Try
2929
import scala.util.chaining.*
3030

31-
final case class CrossSources(
31+
/** CrossSources with unwrapped scripts, use [[withWrappedScripts]] to wrap them and obtain an
32+
* instance of CrossSources
33+
*
34+
* See [[CrossSources]] for more information
35+
*
36+
* @param paths
37+
* paths and realtive paths to sources on disk, wrapped in their build requirements
38+
* @param inMemory
39+
* in memory sources (e.g. snippets) wrapped in their build requirements
40+
* @param defaultMainClass
41+
* @param resourceDirs
42+
* @param buildOptions
43+
* build options from sources
44+
* @param unwrappedScripts
45+
* in memory script sources, their code must be wrapped before compiling
46+
*/
47+
sealed class UnwrappedCrossSources(
3248
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
3349
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
3450
defaultMainClass: Option[String],
3551
resourceDirs: Seq[WithBuildRequirements[os.Path]],
36-
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
52+
buildOptions: Seq[WithBuildRequirements[BuildOptions]],
53+
unwrappedScripts: Seq[WithBuildRequirements[Sources.UnwrappedScript]]
3754
) {
3855

56+
/** For all unwrapped script sources contained in this object wrap them according to provided
57+
* BuildOptions
58+
*
59+
* @param buildOptions
60+
* options used to choose the script wrapper
61+
* @return
62+
* CrossSources with all the scripts wrapped
63+
*/
64+
def withWrappedScripts(buildOptions: BuildOptions): CrossSources = {
65+
val codeWrapper = ScriptPreprocessor.getScriptWrapper(buildOptions)
66+
67+
val wrappedScripts = unwrappedScripts.map { unwrapppedWithRequirements =>
68+
unwrapppedWithRequirements.map(_.wrap(codeWrapper))
69+
}
70+
71+
CrossSources(
72+
paths,
73+
inMemory ++ wrappedScripts,
74+
defaultMainClass,
75+
resourceDirs,
76+
this.buildOptions
77+
)
78+
}
79+
3980
def sharedOptions(baseOptions: BuildOptions): BuildOptions =
4081
buildOptions
4182
.filter(_.requirements.isEmpty)
4283
.map(_.value)
4384
.foldLeft(baseOptions)(_ orElse _)
4485

45-
private def needsScalaVersion =
86+
protected def needsScalaVersion =
4687
paths.exists(_.needsScalaVersion) ||
4788
inMemory.exists(_.needsScalaVersion) ||
4889
resourceDirs.exists(_.needsScalaVersion) ||
4990
buildOptions.exists(_.needsScalaVersion)
91+
}
5092

93+
/** Information gathered from preprocessing command inputs - sources and build options from using
94+
* directives
95+
*
96+
* @param paths
97+
* paths and realtive paths to sources on disk, wrapped in their build requirements
98+
* @param inMemory
99+
* in memory sources (e.g. snippets and wrapped scripts) wrapped in their build requirements
100+
* @param defaultMainClass
101+
* @param resourceDirs
102+
* @param buildOptions
103+
* build options from sources
104+
*/
105+
final case class CrossSources(
106+
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
107+
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
108+
defaultMainClass: Option[String],
109+
resourceDirs: Seq[WithBuildRequirements[os.Path]],
110+
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
111+
) extends UnwrappedCrossSources(
112+
paths,
113+
inMemory,
114+
defaultMainClass,
115+
resourceDirs,
116+
buildOptions,
117+
Nil
118+
) {
51119
def scopedSources(baseOptions: BuildOptions): Either[BuildException, ScopedSources] = either {
52120

53121
val sharedOptions0 = sharedOptions(baseOptions)
@@ -114,7 +182,6 @@ final case class CrossSources(
114182
crossSources0.buildOptions.map(_.scopedValue(defaultScope))
115183
)
116184
}
117-
118185
}
119186

120187
object CrossSources {
@@ -141,7 +208,7 @@ object CrossSources {
141208
suppressWarningOptions: SuppressWarningOptions,
142209
exclude: Seq[Positioned[String]] = Nil,
143210
maybeRecoverOnError: BuildException => Option[BuildException] = e => Some(e)
144-
)(using ScalaCliInvokeData): Either[BuildException, (CrossSources, Inputs)] = either {
211+
)(using ScalaCliInvokeData): Either[BuildException, (UnwrappedCrossSources, Inputs)] = either {
145212

146213
def preprocessSources(elems: Seq[SingleElement])
147214
: Either[BuildException, Seq[PreprocessedSource]] =
@@ -262,6 +329,16 @@ object CrossSources {
262329
Sources.InMemory(m.originalPath, m.relPath, m.code, m.ignoreLen)
263330
) -> m.directivesPositions
264331
}
332+
val unwrappedScriptsWithDirectivePositions
333+
: Seq[(WithBuildRequirements[Sources.UnwrappedScript], Option[DirectivesPositions])] =
334+
preprocessedSources.collect {
335+
case m: PreprocessedSource.UnwrappedScript =>
336+
val baseReqs0 = baseReqs(m.scopePath)
337+
WithBuildRequirements(
338+
m.requirements.fold(baseReqs0)(_ orElse baseReqs0),
339+
Sources.UnwrappedScript(m.originalPath, m.relPath, m.wrapScriptFun)
340+
) -> m.directivesPositions
341+
}
265342

266343
val resourceDirs: Seq[WithBuildRequirements[os.Path]] = allInputs.elements.collect {
267344
case r: ResourceDirectory =>
@@ -271,14 +348,20 @@ object CrossSources {
271348
)
272349

273350
lazy val allPathsWithDirectivesByScope: Map[Scope, Seq[(os.Path, DirectivesPositions)]] =
274-
(pathsWithDirectivePositions ++ inMemoryWithDirectivePositions)
351+
(pathsWithDirectivePositions ++
352+
inMemoryWithDirectivePositions ++
353+
unwrappedScriptsWithDirectivePositions)
275354
.flatMap { (withBuildRequirements, directivesPositions) =>
276355
val scope = withBuildRequirements.scopedValue(Scope.Main).scope
277356
val path: os.Path = withBuildRequirements.value match
278357
case im: Sources.InMemory =>
279358
im.originalPath match
280359
case Right((_, p: os.Path)) => p
281360
case _ => inputs.workspace / im.generatedRelPath
361+
case us: Sources.UnwrappedScript =>
362+
us.originalPath match
363+
case Right((_, p: os.Path)) => p
364+
case _ => inputs.workspace / us.generatedRelPath
282365
case (p: os.Path, _) => p
283366
directivesPositions.map((path, scope, _))
284367
}
@@ -306,9 +389,20 @@ object CrossSources {
306389
}
307390
}
308391

309-
val paths = pathsWithDirectivePositions.map(_._1)
310-
val inMemory = inMemoryWithDirectivePositions.map(_._1)
311-
(CrossSources(paths, inMemory, defaultMainClassOpt, resourceDirs, buildOptions), allInputs)
392+
val paths = pathsWithDirectivePositions.map(_._1)
393+
val inMemory = inMemoryWithDirectivePositions.map(_._1)
394+
val unwrappedScripts = unwrappedScriptsWithDirectivePositions.map(_._1)
395+
(
396+
UnwrappedCrossSources(
397+
paths,
398+
inMemory,
399+
defaultMainClassOpt,
400+
resourceDirs,
401+
buildOptions,
402+
unwrappedScripts
403+
),
404+
allInputs
405+
)
312406
}
313407

314408
private def resolveInputsFromSources(sources: Seq[Positioned[os.Path]], enableMarkdown: Boolean) =

modules/build/src/main/scala/scala/build/Sources.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ object Sources {
7373
topWrapperLen: Int
7474
)
7575

76+
final case class UnwrappedScript(
77+
originalPath: Either[String, (os.SubPath, os.Path)],
78+
generatedRelPath: os.RelPath,
79+
wrapScriptFun: CodeWrapper => (String, Int)
80+
) {
81+
def wrap(wrapper: CodeWrapper): InMemory = {
82+
val (content, topWrapperLen) = wrapScriptFun(wrapper)
83+
InMemory(originalPath, generatedRelPath, content, topWrapperLen)
84+
}
85+
}
86+
7687
/** The default preprocessor list.
7788
*
7889
* @param codeWrapper
@@ -86,13 +97,12 @@ object Sources {
8697
* @return
8798
*/
8899
def defaultPreprocessors(
89-
codeWrapper: CodeWrapper,
90100
archiveCache: ArchiveCache[Task],
91101
javaClassNameVersionOpt: Option[String],
92102
javaCommand: () => String
93103
): Seq[Preprocessor] =
94104
Seq(
95-
ScriptPreprocessor(codeWrapper),
105+
ScriptPreprocessor,
96106
MarkdownPreprocessor,
97107
JavaPreprocessor(archiveCache, javaClassNameVersionOpt, javaCommand),
98108
ScalaPreprocessor,

modules/build/src/main/scala/scala/build/bsp/BspClient.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package scala.build.bsp
22

3-
import ch.epfl.scala.{bsp4j => b}
3+
import ch.epfl.scala.bsp4j as b
44

55
import java.lang.Boolean as JBoolean
66
import java.net.URI
@@ -10,6 +10,7 @@ import java.util.concurrent.{ConcurrentHashMap, ExecutorService}
1010
import scala.build.Position.File
1111
import scala.build.bsp.protocol.TextEdit
1212
import scala.build.errors.{BuildException, CompositeBuildException, Diagnostic, Severity}
13+
import scala.build.internal.util.WarningMessages
1314
import scala.build.postprocessing.LineConversion
1415
import scala.build.{BloopBuildClient, GeneratedSource, Logger}
1516
import scala.jdk.CollectionConverters.*
@@ -48,6 +49,16 @@ class BspClient(
4849
val diag0 = diag.duplicate()
4950
diag0.getRange.getStart.setLine(startLine)
5051
diag0.getRange.getEnd.setLine(endLine)
52+
53+
if (
54+
diag0.getMessage.contains(
55+
"cannot be a main method since it cannot be accessed statically"
56+
)
57+
)
58+
diag0.setMessage(
59+
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
60+
)
61+
5162
diag0
5263
}
5364
updatedDiagOpt.getOrElse(diag)

modules/build/src/main/scala/scala/build/bsp/BspImpl.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import scala.build.errors.{
2222
ParsingInputsException
2323
}
2424
import scala.build.input.{Inputs, ScalaCliInvokeData}
25-
import scala.build.internal.{Constants, CustomCodeWrapper}
25+
import scala.build.internal.Constants
2626
import scala.build.options.{BuildOptions, Scope}
2727
import scala.collection.mutable.ListBuffer
2828
import scala.concurrent.duration.DurationInt
@@ -101,7 +101,6 @@ final class BspImpl(
101101
CrossSources.forInputs(
102102
inputs = inputs,
103103
preprocessors = Sources.defaultPreprocessors(
104-
buildOptions.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
105104
buildOptions.archiveCache,
106105
buildOptions.internal.javaClassNameVersionOpt,
107106
() => buildOptions.javaHome().value.javaCommand
@@ -113,16 +112,21 @@ final class BspImpl(
113112
).left.map((_, Scope.Main))
114113
}
115114

115+
val wrappedScriptsSources = crossSources.withWrappedScripts(buildOptions)
116+
116117
if (verbosity >= 3)
117-
pprint.err.log(crossSources)
118+
pprint.err.log(wrappedScriptsSources)
118119

119-
val scopedSources = value(crossSources.scopedSources(buildOptions).left.map((_, Scope.Main)))
120+
val scopedSources =
121+
value(wrappedScriptsSources.scopedSources(buildOptions).left.map((_, Scope.Main)))
120122

121123
if (verbosity >= 3)
122124
pprint.err.log(scopedSources)
123125

124-
val sourcesMain = scopedSources.sources(Scope.Main, crossSources.sharedOptions(buildOptions))
125-
val sourcesTest = scopedSources.sources(Scope.Test, crossSources.sharedOptions(buildOptions))
126+
val sourcesMain =
127+
scopedSources.sources(Scope.Main, wrappedScriptsSources.sharedOptions(buildOptions))
128+
val sourcesTest =
129+
scopedSources.sources(Scope.Test, wrappedScriptsSources.sharedOptions(buildOptions))
126130

127131
if (verbosity >= 3)
128132
pprint.err.log(sourcesMain)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package scala.build.internal
2+
3+
/** Script code wrapper that solves problem of deadlocks when using threads. The code is placed in a
4+
* class instance constructor, the created object is kept in 'mainObjectCode'.script to support
5+
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
6+
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
7+
*/
8+
case object ClassCodeWrapper extends CodeWrapper {
9+
private val userCodeNestingLevel = 1
10+
def apply(
11+
code: String,
12+
pkgName: Seq[Name],
13+
indexedWrapperName: Name,
14+
extraCode: String,
15+
scriptPath: String
16+
) = {
17+
val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
18+
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
19+
val mainObjectCode =
20+
AmmUtil.normalizeNewlines(s"""|object $name {
21+
| private var args$$opt0 = Option.empty[Array[String]]
22+
| def args$$set(args: Array[String]): Unit = {
23+
| args$$opt0 = Some(args)
24+
| }
25+
| def args$$opt: Option[Array[String]] = args$$opt0
26+
| def args$$: Array[String] = args$$opt.getOrElse {
27+
| sys.error("No arguments passed to this script")
28+
| }
29+
|
30+
| lazy val script = new $wrapperClassName
31+
|
32+
| def main(args: Array[String]): Unit = {
33+
| args$$set(args)
34+
| script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
35+
| }
36+
|}
37+
|
38+
|export $name.script as ${indexedWrapperName.backticked}
39+
|""".stripMargin)
40+
41+
val packageDirective =
42+
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
43+
44+
// indentation is important in the generated code, so we don't want scalafmt to touch that
45+
// format: off
46+
val top = AmmUtil.normalizeNewlines(s"""
47+
$packageDirective
48+
49+
50+
final class $wrapperClassName {
51+
def args = $name.args$$
52+
def scriptPath = \"\"\"$scriptPath\"\"\"
53+
""")
54+
val bottom = AmmUtil.normalizeNewlines(s"""
55+
$extraCode
56+
}
57+
58+
$mainObjectCode
59+
""")
60+
// format: on
61+
62+
(top, bottom, userCodeNestingLevel)
63+
}
64+
}

0 commit comments

Comments
 (0)