Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion modal-go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,16 @@ var defaultProfile Profile
// clientProfile is the actual profile, from defaultProfile + InitializeClient().
var clientProfile Profile

// client is the default Modal client that talks to the control plane.
var client pb.ModalClientClient

// clients is a map of server URL to input-plane client.
var inputPlaneClients = map[string]pb.ModalClientClient{}

// authToken is the auth token received from the control plane on the first request, and sent with all
// subsequent requests to both the control plane and the input plane.
var authToken string

func init() {
defaultConfig, _ = readConfigFile()
defaultProfile = getProfile(os.Getenv("MODAL_PROFILE"))
Expand Down Expand Up @@ -108,7 +116,23 @@ func InitializeClient(options ClientOptions) error {
return err
}

// newClient dials api.modal.com with auth/timeout/retry interceptors installed.
// getOrCreateInputPlaneClient returns a client for the given server URL, creating it if it doesn't exist.
func getOrCreateInputPlaneClient(serverURL string) (pb.ModalClientClient, error) {
if client, ok := inputPlaneClients[serverURL]; ok {
return client, nil
}

profile := clientProfile
profile.ServerURL = serverURL
_, client, err := newClient(profile)
if err != nil {
return nil, err
}
inputPlaneClients[serverURL] = client
return client, nil
}

// newClient dials the given server URL with auth/timeout/retry interceptors installed.
// It returns (conn, stub). Close the conn when done.
func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error) {
var target string
Expand All @@ -131,6 +155,7 @@ func newClient(profile Profile) (*grpc.ClientConn, pb.ModalClientClient, error)
grpc.MaxCallSendMsgSize(maxMessageSize),
),
grpc.WithChainUnaryInterceptor(
authTokenInterceptor(),
retryInterceptor(),
timeoutInterceptor(),
),
Expand All @@ -157,6 +182,37 @@ func clientContext(ctx context.Context) (context.Context, error) {
), nil
}

// authTokenInterceptor handles sending and receiving the "x-modal-auth-token" header.
// We receive an auth token from the control plane on our first request. We then include that auth token in every
// subsequent request to both the control plane and the input plane.
func authTokenInterceptor() grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
inv grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
var headers, trailers metadata.MD
// Add authToken to outgoing context if it's set
if authToken != "" {
ctx = metadata.AppendToOutgoingContext(ctx, "x-modal-auth-token", authToken)
}
opts = append(opts, grpc.Header(&headers), grpc.Trailer(&trailers))
err := inv(ctx, method, req, reply, cc, opts...)
// If we're talking to the control plane, and no auth token was sent, it will return one.
// The python server returns it in the trailers, the worker returns it in the headers.
if val, ok := headers["x-modal-auth-token"]; ok {
authToken = val[0]
} else if val, ok := trailers["x-modal-auth-token"]; ok {
authToken = val[0]
}

return err
}
}

func timeoutInterceptor() grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
Expand Down
25 changes: 20 additions & 5 deletions modal-go/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ func timeNowSeconds() float64 {

// Function references a deployed Modal Function.
type Function struct {
FunctionId string
MethodName *string // used for class methods
ctx context.Context
FunctionId string
MethodName *string // used for class methods
InputPlaneURL *string // if nil, use control plane
ctx context.Context
}

// FunctionLookup looks up an existing Function.
Expand Down Expand Up @@ -64,7 +65,13 @@ func FunctionLookup(ctx context.Context, appName string, name string, options *L
return nil, err
}

return &Function{FunctionId: resp.GetFunctionId(), ctx: ctx}, nil
var inputPlaneUrl *string
if meta := resp.GetHandleMetadata(); meta != nil {
if url := meta.GetInputPlaneUrl(); url != "" {
inputPlaneUrl = &url
}
}
return &Function{FunctionId: resp.GetFunctionId(), InputPlaneURL: inputPlaneUrl, ctx: ctx}, nil
}

// Serialize Go data types to the Python pickle format.
Expand Down Expand Up @@ -122,7 +129,7 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
if err != nil {
return nil, err
}
invocation, err := createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
invocation, err := f.createRemoteInvocation(input)
if err != nil {
return nil, err
}
Expand All @@ -144,6 +151,14 @@ func (f *Function) Remote(args []any, kwargs map[string]any) (any, error) {
}
}

// createRemoteInvocation creates an Invocation using either the input plane or control plane.
func (f *Function) createRemoteInvocation(input *pb.FunctionInput) (invocation, error) {
if f.InputPlaneURL != nil {
return createInputPlaneInvocation(f.ctx, *f.InputPlaneURL, f.FunctionId, input)
}
return createControlPlaneInvocation(f.ctx, f.FunctionId, input, pb.FunctionCallInvocationType_FUNCTION_CALL_INVOCATION_TYPE_SYNC)
}

// Spawn starts running a single input on a remote function.
func (f *Function) Spawn(args []any, kwargs map[string]any) (*FunctionCall, error) {
input, err := f.createInput(args, kwargs)
Expand Down
122 changes: 106 additions & 16 deletions modal-go/invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
"google.golang.org/protobuf/proto"
)

type invocation interface {
awaitOutput(timeout *time.Duration) (any, error)
retry(retryCount uint32) error
}

// controlPlaneInvocation implements the invocation interface.
type controlPlaneInvocation struct {
FunctionCallId string
Expand Down Expand Up @@ -52,7 +57,7 @@ func controlPlaneInvocationFromFunctionCallId(ctx context.Context, functionCallI
}

func (c *controlPlaneInvocation) awaitOutput(timeout *time.Duration) (any, error) {
return pollFunctionOutput(c.ctx, c.FunctionCallId, timeout)
return pollFunctionOutput(c.ctx, c.getOutput, timeout)
}

func (c *controlPlaneInvocation) retry(retryCount uint32) error {
Expand All @@ -75,8 +80,102 @@ func (c *controlPlaneInvocation) retry(retryCount uint32) error {
return nil
}

// Poll for outputs for a given FunctionCall ID.
func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *time.Duration) (any, error) {
// getOutput fetches the output for the current function call with a timeout in milliseconds.
func (c *controlPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
response, err := client.FunctionGetOutputs(c.ctx, pb.FunctionGetOutputsRequest_builder{
FunctionCallId: c.FunctionCallId,
MaxValues: 1,
Timeout: float32(timeout.Seconds()),
LastEntryId: "0-0",
ClearOnSuccess: true,
RequestedAt: timeNowSeconds(),
}.Build())
if err != nil {
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
}
outputs := response.GetOutputs()
if len(outputs) > 0 {
return outputs[0], nil
}
return nil, nil
}

// InputPlaneInvocation implements the Invocation interface for the input plane.
type inputPlaneInvocation struct {
client pb.ModalClientClient
functionId string
input *pb.FunctionPutInputsItem
attemptToken string
ctx context.Context
}

// CreateInputPlaneInvocation creates a new InputPlaneInvocation by starting an attempt.
func createInputPlaneInvocation(ctx context.Context, inputPlaneURL string, functionId string, input *pb.FunctionInput) (*inputPlaneInvocation, error) {
functionPutInputsItem := pb.FunctionPutInputsItem_builder{
Idx: 0,
Input: input,
}.Build()
client, err := getOrCreateInputPlaneClient(inputPlaneURL)
if err != nil {
return nil, err
}
attemptStartResp, err := client.AttemptStart(ctx, pb.AttemptStartRequest_builder{
FunctionId: functionId,
Input: functionPutInputsItem,
}.Build())
if err != nil {
return nil, err
}
return &inputPlaneInvocation{
client: client,
functionId: functionId,
input: functionPutInputsItem,
attemptToken: attemptStartResp.GetAttemptToken(),
ctx: ctx,
}, nil
}

// awaitOutput waits for the output with an optional timeout.
func (i *inputPlaneInvocation) awaitOutput(timeout *time.Duration) (any, error) {
return pollFunctionOutput(i.ctx, i.getOutput, timeout)
}

// getOutput fetches the output for the current attempt.
func (i *inputPlaneInvocation) getOutput(timeout time.Duration) (*pb.FunctionGetOutputsItem, error) {
resp, err := i.client.AttemptAwait(i.ctx, pb.AttemptAwaitRequest_builder{
AttemptToken: i.attemptToken,
RequestedAt: timeNowSeconds(),
TimeoutSecs: float32(timeout.Seconds()),
}.Build())
if err != nil {
return nil, fmt.Errorf("AttemptAwait failed: %w", err)
}
return resp.GetOutput(), nil
}

// retry retries the invocation.
func (i *inputPlaneInvocation) retry(retryCount uint32) error {
// We ignore retryCount - it is used only by controlPlaneInvocation.
resp, err := i.client.AttemptRetry(context.Background(), pb.AttemptRetryRequest_builder{
FunctionId: i.functionId,
Input: i.input,
AttemptToken: i.attemptToken,
}.Build())
if err != nil {
return err
}
i.attemptToken = resp.GetAttemptToken()
return nil
}

// getOutput is a function type that takes a timeout and returns a FunctionGetOutputsItem or nil, and an error.
// Used by `pollForOutputs` to fetch from either the control plane or the input plane, depending on the implementation.
type getOutput func(timeout time.Duration) (*pb.FunctionGetOutputsItem, error)

// pollFunctionOutput repeatedly tries to fetch an output using the provided `getOutput` function, and the specified
// timeout value. We use a timeout value of 55 seconds if the caller does not specify a timeout value, or if the
// specified timeout value is greater than 55 seconds.
func pollFunctionOutput(ctx context.Context, getOutput getOutput, timeout *time.Duration) (any, error) {
startTime := time.Now()
pollTimeout := outputsTimeout
if timeout != nil {
Expand All @@ -85,23 +184,14 @@ func pollFunctionOutput(ctx context.Context, functionCallId string, timeout *tim
}

for {
response, err := client.FunctionGetOutputs(ctx, pb.FunctionGetOutputsRequest_builder{
FunctionCallId: functionCallId,
MaxValues: 1,
Timeout: float32(pollTimeout.Seconds()),
LastEntryId: "0-0",
ClearOnSuccess: true,
RequestedAt: timeNowSeconds(),
}.Build())
output, err := getOutput(pollTimeout)
if err != nil {
return nil, fmt.Errorf("FunctionGetOutputs failed: %w", err)
return nil, err
}

// Output serialization may fail if any of the output items can't be deserialized
// into a supported Go type. Users are expected to serialize outputs correctly.
outputs := response.GetOutputs()
if len(outputs) > 0 {
return processResult(ctx, outputs[0].GetResult(), outputs[0].GetDataFormat())
if output != nil {
return processResult(ctx, output.GetResult(), output.GetDataFormat())
}

if timeout != nil {
Expand Down
13 changes: 13 additions & 0 deletions modal-go/test/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,16 @@ func TestFunctionNotFound(t *testing.T) {
_, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "not_a_real_function", nil)
g.Expect(err).Should(gomega.BeAssignableToTypeOf(modal.NotFoundError{}))
}

func TestFunctionCallInputPlane(t *testing.T) {
t.Parallel()
g := gomega.NewWithT(t)

function, err := modal.FunctionLookup(context.Background(), "libmodal-test-support", "input_plane", nil)
g.Expect(err).ShouldNot(gomega.HaveOccurred())

// Try the same, but with args.
result, err := function.Remote([]any{"hello"}, nil)
g.Expect(err).ShouldNot(gomega.HaveOccurred())
g.Expect(result).Should(gomega.Equal("output: hello"))
}
43 changes: 42 additions & 1 deletion modal-js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { getProfile, type Profile } from "./config";

const defaultProfile = getProfile(process.env["MODAL_PROFILE"]);

let modalAuthToken: string | undefined;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a global variable right now and will end up shared between multiple clients. Instead I think it needs to be declared in the authMiddleware() factory between line 19 function authMiddleware(...) and line 20 return async function* authMiddleware(...), so it doesn't end up becoming accidentally shared global state.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually intentionally share the token between the control plane client and the input plane client. Here's the flow:

  1. We call the control plane, and do not specify an auth token (because we don't have one yet)
  2. The control plane sees there is no auth token - so it generates one, and returns it to the client
  3. The client the specifies that auth token in every subsequent request to both the control plane and the input plane.
    I will add a comment explaining this.


/** gRPC client middleware to add auth token to request. */
function authMiddleware(profile: Profile): ClientMiddleware {
return async function* authMiddleware<Request, Response>(
Expand All @@ -36,6 +38,29 @@ function authMiddleware(profile: Profile): ClientMiddleware {
options.metadata.set("x-modal-client-version", "1.0.0"); // CLIENT VERSION: Behaves like this Python SDK version
options.metadata.set("x-modal-token-id", tokenId);
options.metadata.set("x-modal-token-secret", tokenSecret);
if (modalAuthToken) {
options.metadata.set("x-modal-auth-token", modalAuthToken);
}

// We receive an auth token from the control plane on our first request. We then include that auth token in every
// subsequent request to both the control plane and the input plane. The python server returns it in the trailers,
// the worker returns it in the headers.
const prevOnHeader = options.onHeader;
options.onHeader = (header) => {
const token = header.get("x-modal-auth-token");
if (token) {
modalAuthToken = token;
}
prevOnHeader?.(header);
};
const prevOnTrailer = options.onTrailer;
options.onTrailer = (trailer) => {
const token = trailer.get("x-modal-auth-token");
if (token) {
modalAuthToken = token;
}
prevOnTrailer?.(trailer);
};
return yield* call.next(call.request, options);
};
}
Expand Down Expand Up @@ -199,6 +224,23 @@ const retryMiddleware: ClientMiddleware<RetryOptions> =
}
};

/** Map of server URL to input-plane client. */
const inputPlaneClients: Record<string, ReturnType<typeof createClient>> = {};

/** Returns a client for the given server URL, creating it if it doesn't exist. */
export const getOrCreateInputPlaneClient = (
serverUrl: string,
): ReturnType<typeof createClient> => {
const client = inputPlaneClients[serverUrl];
if (client) {
return client;
}
const profile = { ...clientProfile, serverUrl };
const newClient = createClient(profile);
inputPlaneClients[serverUrl] = newClient;
return newClient;
};

function createClient(profile: Profile) {
// Channels don't do anything until you send a request on them.
// Ref: https://github.com/modal-labs/modal-client/blob/main/modal/_utils/grpc_utils.py
Expand All @@ -207,7 +249,6 @@ function createClient(profile: Profile) {
"grpc.max_send_message_length": 100 * 1024 * 1024,
"grpc-node.flow_control_window": 64 * 1024 * 1024,
});

return createClientFactory()
.use(authMiddleware(profile))
.use(retryMiddleware)
Expand Down
Loading