Skip to content

Commit f4ec62e

Browse files
committed
[GR-44211] Use Env#newTruffleThreadBuilder for Ruby Fibers
1 parent 68f5b26 commit f4ec62e

File tree

8 files changed

+113
-30
lines changed

8 files changed

+113
-30
lines changed

spec/tags/core/exception/top_level_tags.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@ slow:An Exception reaching the top level with a custom backtrace is printed on S
33
slow:An Exception reaching the top level the Exception#cause is printed to STDERR with backtraces
44
slow:An Exception reaching the top level kills all threads and fibers, ensure clauses are only run for threads current fibers, not for suspended fibers with ensure on the root fiber
55
slow:An Exception reaching the top level kills all threads and fibers, ensure clauses are only run for threads current fibers, not for suspended fibers with ensure on non-root fiber
6-
fails(GR-44211):An Exception reaching the top level kills all threads and fibers, ensure clauses are only run for threads current fibers, not for suspended fibers with ensure on non-root fiber

src/main/java/org/truffleruby/RubyLanguage.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,17 @@ public void invalidateTracingAssumption() {
406406
tracingAssumption = tracingCyclicAssumption.getAssumption();
407407
}
408408

409+
private boolean multiThreading = false;
410+
411+
public boolean isMultiThreaded() {
412+
return multiThreading;
413+
}
414+
415+
@Override
416+
protected void initializeMultiThreading(RubyContext context) {
417+
this.multiThreading = true;
418+
}
419+
409420
@Override
410421
protected void initializeMultipleContexts() {
411422
LOGGER.fine("initializeMultipleContexts()");

src/main/java/org/truffleruby/core/fiber/FiberManager.java

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,12 @@ public void initialize(RubyFiber fiber, boolean blocking, RubyProc block, Node c
6565
final TruffleContext truffleContext = context.getEnv().getContext();
6666

6767
context.getThreadManager().leaveAndEnter(truffleContext, currentNode, () -> {
68-
context.getThreadManager().spawnFiber(fiber, sourceSection,
69-
() -> fiberMain(context, fiber, block, currentNode));
68+
ThreadManager threadManager = context.getThreadManager();
69+
Thread thread = threadManager.createFiberJavaThread(fiber, sourceSection,
70+
() -> beforeEnter(fiber, currentNode),
71+
() -> fiberMain(context, fiber, block, currentNode),
72+
() -> afterLeave(fiber), currentNode);
73+
thread.start();
7074
waitForInitialization(context, fiber, currentNode);
7175
return BlockingAction.SUCCESS;
7276
});
@@ -90,23 +94,25 @@ public static void waitForInitialization(RubyContext context, RubyFiber fiber, N
9094

9195
private static final BranchProfile UNPROFILED = BranchProfile.getUncached();
9296

93-
private void fiberMain(RubyContext context, RubyFiber fiber, RubyProc block, Node currentNode) {
97+
private void beforeEnter(RubyFiber fiber, Node currentNode) {
9498
assert !fiber.isRootFiber() : "Root Fibers execute threadMain() and not fiberMain()";
9599
assertNotEntered("Fibers should start unentered to avoid triggering multithreading");
96100

97101
final Thread thread = Thread.currentThread();
98-
final TruffleContext truffleContext = context.getEnv().getContext();
99-
100102
start(fiber, thread);
101103

102104
// fully initialized
103105
fiber.initializedLatch.countDown();
104106

105-
final FiberMessage message = waitMessage(fiber, currentNode);
106-
fiber.rubyThread.setCurrentFiber(fiber);
107+
fiber.firstMessage = waitMessage(fiber, currentNode);
107108

108109
// enter() polls so we need the current Fiber to be set before enter()
109-
final Object prev = truffleContext.enter(currentNode);
110+
fiber.rubyThread.setCurrentFiber(fiber);
111+
}
112+
113+
private void fiberMain(RubyContext context, RubyFiber fiber, RubyProc block, Node currentNode) {
114+
final FiberMessage message = fiber.firstMessage;
115+
fiber.firstMessage = null;
110116

111117
FiberMessage lastMessage = null;
112118
try {
@@ -134,18 +140,24 @@ private void fiberMain(RubyContext context, RubyFiber fiber, RubyProc block, Nod
134140
final RuntimeException exception = ThreadManager.printInternalError(e);
135141
lastMessage = new FiberExceptionMessage(exception);
136142
} finally {
137-
final RubyFiber returnFiber = lastMessage == null ? null : getReturnFiber(fiber, currentNode, UNPROFILED);
143+
fiber.lastMessage = lastMessage;
144+
fiber.returnFiber = lastMessage == null ? null : getReturnFiber(fiber, currentNode, UNPROFILED);
138145

139146
// Perform all cleanup before resuming the parent Fiber
140147
// Make sure that other fibers notice we are dead before they gain control back
141148
fiber.status = FiberStatus.TERMINATED;
142149
// Leave context before addToMessageQueue() -> parent Fiber starts executing
143-
truffleContext.leave(currentNode, prev);
144-
cleanup(fiber, thread);
150+
}
151+
}
145152

146-
if (lastMessage != null) {
147-
addToMessageQueue(returnFiber, lastMessage);
148-
}
153+
private void afterLeave(RubyFiber fiber) {
154+
final Thread thread = Thread.currentThread();
155+
cleanup(fiber, thread);
156+
157+
if (fiber.lastMessage != null) {
158+
addToMessageQueue(fiber.returnFiber, fiber.lastMessage);
159+
fiber.returnFiber = null;
160+
fiber.lastMessage = null;
149161
}
150162
}
151163

@@ -161,7 +173,7 @@ public RubyFiber getReturnFiber(RubyFiber currentFiber, Node currentNode, Branch
161173
return previousFiber;
162174
} else {
163175

164-
if (currentFiber == rootFiber) {
176+
if (currentFiber == rootFiber) { // Note: this is always false for the fiberMain() caller
165177
errorProfile.enter();
166178
throw new RaiseException(context, context.getCoreExceptions().yieldFromRootFiberError(currentNode));
167179
}

src/main/java/org/truffleruby/core/fiber/FiberNodes.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ public abstract static class InitializeNode extends PrimitiveArrayArgumentsNode
121121
@TruffleBoundary
122122
@Specialization
123123
protected Object initialize(RubyFiber fiber, boolean blocking, RubyProc block) {
124+
if (!getContext().getEnv().isCreateThreadAllowed()) {
125+
// Because TruffleThreadBuilder#build denies it already, before the thread is even started.
126+
// The permission is called allowCreateThread, so it kind of makes sense.
127+
throw new RaiseException(getContext(),
128+
coreExceptions().securityError("fibers not allowed with allowCreateThread(false)", this));
129+
}
130+
124131
getContext().fiberManager.initialize(fiber, blocking, block, this);
125132
return nil;
126133
}

src/main/java/org/truffleruby/core/fiber/RubyFiber.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ public enum FiberStatus {
9292
boolean blocking = true;
9393
public RubyArray cGlobalVariablesDuringInitFunction;
9494

95+
// To pass state between beforeEnter(), fiberMain() and afterLeave()
96+
FiberManager.FiberMessage firstMessage;
97+
RubyFiber returnFiber;
98+
FiberManager.FiberMessage lastMessage;
99+
95100
public RubyFiber(
96101
RubyClass rubyClass,
97102
Shape shape,

src/main/java/org/truffleruby/core/thread/ThreadManager.java

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.function.Supplier;
2222

2323
import com.oracle.truffle.api.CompilerAsserts;
24+
import com.oracle.truffle.api.CompilerDirectives;
2425
import com.oracle.truffle.api.CompilerDirectives.ValueType;
2526
import com.oracle.truffle.api.TruffleContext;
2627
import com.oracle.truffle.api.TruffleOptions;
@@ -164,16 +165,29 @@ private static ThreadFactory getVirtualThreadFactory() {
164165

165166
@CompilationFinal static ThreadFactory VIRTUAL_THREAD_FACTORY = getVirtualThreadFactory();
166167

167-
private Thread createFiberJavaThread(RubyFiber fiber, SourceSection sourceSection, Runnable runnable) {
168+
public Thread createFiberJavaThread(RubyFiber fiber, SourceSection sourceSection, Runnable beforeEnter,
169+
Runnable body, Runnable afterLeave, Node node) {
168170
if (context.isPreInitializing()) {
169-
throw new UnsupportedOperationException("fibers should not be created while pre-initializing the context");
171+
throw CompilerDirectives
172+
.shouldNotReachHere("fibers should not be created while pre-initializing the context");
170173
}
171174

172175
final Thread thread;
173176
if (context.getOptions().VIRTUAL_THREAD_FIBERS) {
174-
thread = VIRTUAL_THREAD_FACTORY.newThread(runnable);
177+
thread = VIRTUAL_THREAD_FACTORY.newThread(() -> {
178+
var truffleContext = context.getEnv().getContext();
179+
beforeEnter.run();
180+
Object prev = truffleContext.enter(node);
181+
try {
182+
body.run();
183+
} finally {
184+
truffleContext.leave(node, prev);
185+
afterLeave.run();
186+
}
187+
});
175188
} else {
176-
thread = new Thread(runnable); // context.getEnv().createUnenteredThread(runnable);
189+
thread = context.getEnv().newTruffleThreadBuilder(body).beforeEnter(beforeEnter).afterLeave(afterLeave)
190+
.build();
177191
}
178192

179193
language.rubyThreadInitMap.put(thread, fiber.rubyThread);
@@ -186,15 +200,16 @@ private Thread createFiberJavaThread(RubyFiber fiber, SourceSection sourceSectio
186200
return thread;
187201
}
188202

189-
private Thread createJavaThread(Runnable runnable, RubyThread rubyThread, String info) {
203+
private Thread createJavaThread(Runnable runnable, RubyThread rubyThread, String info, Node node) {
190204
if (context.getOptions().SINGLE_THREADED) {
191205
throw new RaiseException(
192206
context,
193-
context.getCoreExceptions().securityError("threads not allowed in single-threaded mode", null));
207+
context.getCoreExceptions().securityError("threads not allowed in single-threaded mode", node));
194208
}
195209

196210
if (context.isPreInitializing()) {
197-
throw new UnsupportedOperationException("threads should not be created while pre-initializing the context");
211+
throw CompilerDirectives
212+
.shouldNotReachHere("threads should not be created while pre-initializing the context");
198213
}
199214

200215
final Thread thread = context.getEnv().newTruffleThreadBuilder(runnable).build();
@@ -235,10 +250,6 @@ private static Thread.UncaughtExceptionHandler uncaughtExceptionHandler(RubyFibe
235250
};
236251
}
237252

238-
public void spawnFiber(RubyFiber fiber, SourceSection sourceSection, Runnable task) {
239-
createFiberJavaThread(fiber, sourceSection, task).start();
240-
}
241-
242253
/** Whether the thread was created by TruffleRuby. */
243254
@TruffleBoundary
244255
public boolean isRubyManagedThread(Thread thread) {
@@ -303,7 +314,8 @@ public void initialize(RubyThread rubyThread, Node currentNode, String info, Str
303314
rubyThread.sourceLocation = info;
304315
final RubyFiber rootFiber = rubyThread.getRootFiber();
305316

306-
final Thread thread = createJavaThread(() -> threadMain(rubyThread, currentNode, task), rubyThread, info);
317+
final Thread thread = createJavaThread(() -> threadMain(rubyThread, currentNode, task), rubyThread, info,
318+
currentNode);
307319
rubyThread.thread = thread;
308320

309321
thread.start();

src/main/java/org/truffleruby/debug/TruffleDebugNodes.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,6 +1382,7 @@ protected RubyString parseName(InternalMethod method) {
13821382
/** Creates a Truffle thread which is not {@link ThreadManager#isRubyManagedThread(java.lang.Thread)}}. */
13831383
@CoreMethod(names = "create_polyglot_thread", onSingleton = true, required = 1)
13841384
public abstract static class CreatePolyglotThread extends CoreMethodArrayArgumentsNode {
1385+
@TruffleBoundary
13851386
@Specialization
13861387
protected Object parseName(Object hostRunnable) {
13871388
Runnable runnable = (Runnable) getContext().getEnv().asHostObject(hostRunnable);
@@ -1411,4 +1412,13 @@ protected long handleCount() {
14111412
}
14121413
}
14131414

1415+
@CoreMethod(names = "multithreaded?", onSingleton = true)
1416+
public abstract static class IsMultiThreadedNode extends CoreMethodArrayArgumentsNode {
1417+
1418+
@Specialization
1419+
protected boolean isMultiThreaded() {
1420+
return getLanguage().isMultiThreaded();
1421+
}
1422+
}
1423+
14141424
}

src/test/java/org/truffleruby/ContextPermissionsTest.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.junit.Test;
1717

1818
import static org.junit.Assert.assertEquals;
19+
import static org.junit.Assert.assertFalse;
1920
import static org.junit.Assert.assertTrue;
2021

2122
public class ContextPermissionsTest {
@@ -88,6 +89,7 @@ public void testThreadsNoNative() throws Throwable {
8889
Assert.assertEquals(3, context.eval("ruby", "1 + 2").asInt());
8990

9091
Assert.assertEquals(7, context.eval("ruby", "Thread.new { 3 + 4 }.value").asInt());
92+
assertTrue(context.eval("ruby", "Truffle::Debug.multithreaded?").asBoolean());
9193

9294
RubyTest.assertThrows(
9395
() -> context.eval("ruby", "File.stat('.')"),
@@ -118,8 +120,29 @@ public void testNoThreadsEnforcesSingleThreadedOption() throws Throwable {
118120
}
119121

120122
@Test
121-
public void testFiberDoesNotTriggerMultiThreading() {
123+
public void testNoThreads() {
122124
try (Context context = Context.newBuilder("ruby").allowCreateThread(false).build()) {
125+
RubyTest.assertThrows(
126+
() -> context.eval("ruby", "Thread.new {}.join"),
127+
e -> {
128+
assertEquals("threads not allowed in single-threaded mode", e.getMessage());
129+
assertEquals("SecurityError", e.getGuestObject().getMetaObject().getMetaQualifiedName());
130+
});
131+
132+
RubyTest.assertThrows(
133+
() -> context.eval("ruby", "Fiber.new {}.resume"),
134+
e -> {
135+
assertEquals("fibers not allowed with allowCreateThread(false)", e.getMessage());
136+
assertEquals("SecurityError", e.getGuestObject().getMetaObject().getMetaQualifiedName());
137+
});
138+
139+
assertFalse(context.eval("ruby", "Truffle::Debug.multithreaded?").asBoolean());
140+
}
141+
}
142+
143+
@Test
144+
public void testFiberDoesNotTriggerMultiThreading() {
145+
try (Context context = Context.newBuilder("ruby").allowCreateThread(true).build()) {
123146
final Value array = context.eval(
124147
"ruby",
125148
"a = [1]; f = Fiber.new { a << 3; Fiber.yield; a << 5 }; a << 2; f.resume; a << 4; f.resume");
@@ -128,12 +151,14 @@ public void testFiberDoesNotTriggerMultiThreading() {
128151
for (int i = 0; i < 5; i++) {
129152
assertEquals(i + 1, array.getArrayElement(i).asInt());
130153
}
154+
155+
assertFalse(context.eval("ruby", "Truffle::Debug.multithreaded?").asBoolean());
131156
}
132157
}
133158

134159
@Test
135160
public void testNestedFiberAndTerminateFiber() {
136-
try (Context context = Context.newBuilder("ruby").allowCreateThread(false).build()) {
161+
try (Context context = Context.newBuilder("ruby").allowCreateThread(true).build()) {
137162
final Value array = context.eval(
138163
"ruby",
139164
"a = []; Fiber.new { a << 1; Fiber.new { a << 2; Fiber.yield; unreachable }.resume; a << 3 }.resume");
@@ -142,6 +167,8 @@ public void testNestedFiberAndTerminateFiber() {
142167
for (int i = 0; i < 3; i++) {
143168
assertEquals(i + 1, array.getArrayElement(i).asInt());
144169
}
170+
171+
assertFalse(context.eval("ruby", "Truffle::Debug.multithreaded?").asBoolean());
145172
}
146173
}
147174

0 commit comments

Comments
 (0)