diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
index f6fd7f21e0066..95e8957e2fec9 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java
@@ -136,6 +136,7 @@ import org.apache.hadoop.thirdparty.protobuf.Message;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.apache.hadoop.security.AuthorizationContext;
 
 /** An abstract IPC service.  IPC calls take a single {@link Writable} as a
  * parameter, and return a {@link Writable} as their value.  A service runs on
@@ -835,6 +836,7 @@ public static class Call implements Schedulable,
     final byte[] clientId;
     private final Span span; // the trace span on the server side
     private final CallerContext callerContext; // the call context
+    private final byte[] authHeader; // the authorization header
     private boolean deferredResponse = false;
     private int priorityLevel;
     // the priority level assigned by scheduler, 0 by default
@@ -863,6 +865,11 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2,
 
     Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId,
         Span span, CallerContext callerContext) {
+      this(id, retryCount, kind, clientId, span, callerContext, null);
+    }
+
+    Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId,
+        Span span, CallerContext callerContext, byte[] authHeader) {
       this.callId = id;
       this.retryCount = retryCount;
       this.timestampNanos = Time.monotonicNowNanos();
@@ -871,6 +878,7 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2,
       this.clientId = clientId;
       this.span = span;
       this.callerContext = callerContext;
+      this.authHeader = authHeader;
       this.clientStateId = Long.MIN_VALUE;
       this.isCallCoordinated = false;
     }
@@ -1051,7 +1059,15 @@ private class RpcCall extends Call {
     RpcCall(Connection connection, int id, int retryCount,
         Writable param, RPC.RpcKind kind, byte[] clientId,
         Span span, CallerContext context) {
-      super(id, retryCount, kind, clientId, span, context);
+      this(connection, id, retryCount, param, kind, clientId,
+          span, context, new byte[0]);
+    }
+
+    @SuppressWarnings("checkstyle:parameterNumber")
+    RpcCall(Connection connection, int id, int retryCount,
+        Writable param, RPC.RpcKind kind, byte[] clientId,
+        Span span, CallerContext context, byte[] authHeader) {
+      super(id, retryCount, kind, clientId, span, context, authHeader);
       this.connection = connection;
       this.rpcRequest = param;
     }
@@ -2783,48 +2799,58 @@ private void processRpcRequest(RpcRequestHeaderProto header,
           .build();
     }
 
-    RpcCall call = new RpcCall(this, header.getCallId(),
-        header.getRetryCount(), rpcRequest,
-        ProtoUtil.convert(header.getRpcKind()),
-        header.getClientId().toByteArray(), span, callerContext);
-
-    // Save the priority level assignment by the scheduler
-    call.setPriorityLevel(callQueue.getPriorityLevel(call));
-    call.markCallCoordinated(false);
-    if(alignmentContext != null && call.rpcRequest != null &&
-        (call.rpcRequest instanceof ProtobufRpcEngine2.RpcProtobufRequest)) {
-      // if call.rpcRequest is not RpcProtobufRequest, will skip the following
-      // step and treat the call as uncoordinated. As currently only certain
-      // ClientProtocol methods request made through RPC protobuf needs to be
-      // coordinated.
-      String methodName;
-      String protoName;
-      ProtobufRpcEngine2.RpcProtobufRequest req =
-          (ProtobufRpcEngine2.RpcProtobufRequest) call.rpcRequest;
-      try {
-        methodName = req.getRequestHeader().getMethodName();
-        protoName = req.getRequestHeader().getDeclaringClassProtocolName();
-        if (alignmentContext.isCoordinatedCall(protoName, methodName)) {
-          call.markCallCoordinated(true);
-          long stateId;
-          stateId = alignmentContext.receiveRequestState(
-              header, getMaxIdleTime());
-          call.setClientStateId(stateId);
+    // Extract the authorization header from the request, if present; it is
+    // attached to the Call here and installed as a thread-local by the
+    // handler thread just before the call is dispatched.
+    byte[] authHeader = null;
+    try {
+      if (header.hasAuthorizationHeader()) {
+        authHeader = header.getAuthorizationHeader().toByteArray();
+      }
+
+      RpcCall call = new RpcCall(this, header.getCallId(),
+          header.getRetryCount(), rpcRequest,
+          ProtoUtil.convert(header.getRpcKind()),
+          header.getClientId().toByteArray(), span, callerContext, authHeader);
+
+      // Save the priority level assignment by the scheduler
+      call.setPriorityLevel(callQueue.getPriorityLevel(call));
+      call.markCallCoordinated(false);
+      if (alignmentContext != null && call.rpcRequest != null &&
+          (call.rpcRequest instanceof ProtobufRpcEngine2.RpcProtobufRequest)) {
+        // If call.rpcRequest is not a RpcProtobufRequest, skip the following
+        // step and treat the call as uncoordinated. Currently only certain
+        // ClientProtocol methods issued through protobuf RPC need to be
+        // coordinated.
+        String methodName;
+        String protoName;
+        ProtobufRpcEngine2.RpcProtobufRequest req =
+            (ProtobufRpcEngine2.RpcProtobufRequest) call.rpcRequest;
+        try {
+          methodName = req.getRequestHeader().getMethodName();
+          protoName = req.getRequestHeader().getDeclaringClassProtocolName();
+          if (alignmentContext.isCoordinatedCall(protoName, methodName)) {
+            call.markCallCoordinated(true);
+            long stateId;
+            stateId = alignmentContext.receiveRequestState(
+                header, getMaxIdleTime());
+            call.setClientStateId(stateId);
+          }
+        } catch (IOException ioe) {
+          throw new RpcServerException("Processing RPC request caught ", ioe);
         }
-      } catch (IOException ioe) {
-        throw new RpcServerException("Processing RPC request caught ", ioe);
       }
-    }
 
-    try {
-      internalQueueCall(call);
-    } catch (RpcServerException rse) {
-      throw rse;
-    } catch (IOException ioe) {
-      throw new FatalRpcServerException(
-          RpcErrorCodeProto.ERROR_RPC_SERVER, ioe);
+      try {
+        internalQueueCall(call);
+      } catch (RpcServerException rse) {
+        throw rse;
+      } catch (IOException ioe) {
+        throw new FatalRpcServerException(
+            RpcErrorCodeProto.ERROR_RPC_SERVER, ioe);
+      }
+      incRpcCount();  // Increment the rpc count
+    } finally {
+      AuthorizationContext.clear();
     }
-    incRpcCount();  // Increment the rpc count
   }
 
   /**
@@ -3046,6 +3072,7 @@ public void run() {
         }
         // always update the current call context
         CallerContext.setCurrent(call.callerContext);
+        AuthorizationContext.setCurrentAuthorizationHeader(call.authHeader);
         UserGroupInformation remoteUser = call.getRemoteUser();
         connDropped = !call.isOpen();
         if (remoteUser != null) {
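A note on the threading model in the Server.java changes above: processRpcRequest() runs in a Reader thread and only copies the header bytes onto the RpcCall; the Handler thread later installs them in AuthorizationContext before dispatching, so anything invoked from the handler observes the header through the thread-local. A minimal sketch of a server-side consumer, assuming a hypothetical DeploymentAuthorizer class that is not part of this patch:

    import org.apache.hadoop.security.AuthorizationContext;

    // Illustration only: any code reached from Server.Handler.run() after
    // setCurrentAuthorizationHeader(call.authHeader) can read the per-call
    // header through the thread-local.
    public class DeploymentAuthorizer {
      public boolean isAllowed(String user, String operation) {
        byte[] header = AuthorizationContext.getCurrentAuthorizationHeader();
        if (header == null) {
          return defaultDecision(user, operation); // no header on this call
        }
        // The bytes are opaque to Hadoop; how to interpret them (e.g. a
        // serialized token or signed credential) is left to the deployment.
        return evaluateToken(header, user, operation);
      }

      private boolean defaultDecision(String user, String op) { return true; }

      private boolean evaluateToken(byte[] token, String user, String op) {
        return true; // placeholder decision logic
      }
    }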
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java
new file mode 100644
index 0000000000000..4c4b32e11517b
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.security;
+
+/**
+ * Utility for managing a thread-local authorization header for RPC calls.
+ */
+public final class AuthorizationContext {
+  private static final ThreadLocal<byte[]> AUTH_HEADER = new ThreadLocal<>();
+
+  private AuthorizationContext() {}
+
+  public static void setCurrentAuthorizationHeader(byte[] header) {
+    AUTH_HEADER.set(header);
+  }
+
+  public static byte[] getCurrentAuthorizationHeader() {
+    return AUTH_HEADER.get();
+  }
+
+  public static void clear() {
+    AUTH_HEADER.remove();
+  }
+}
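AuthorizationContext is a plain thread-local holder, so callers own its lifecycle. A sketch of the intended client-side pattern (assumed, not mandated by this patch; note that nothing on the client clears the header automatically after a call), where fs is an org.apache.hadoop.fs.FileSystem:

    byte[] token = "opaque-credential".getBytes(StandardCharsets.UTF_8);
    AuthorizationContext.setCurrentAuthorizationHeader(token);
    try {
      // Every Hadoop RPC issued from this thread now carries the header;
      // ProtoUtil.makeRpcRequestHeader() attaches it (see the next diff).
      fs.mkdirs(new Path("/example"));
    } finally {
      // Clear it so the header does not leak onto unrelated calls made
      // later from this thread (or from a reused pool thread).
      AuthorizationContext.clear();
    }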
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java
index 883c19c5e7750..307be15db6f34 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java
@@ -32,6 +32,7 @@
 import org.apache.hadoop.tracing.Span;
 import org.apache.hadoop.tracing.Tracer;
 import org.apache.hadoop.tracing.TraceUtils;
+import org.apache.hadoop.security.AuthorizationContext;
 
 import org.apache.hadoop.thirdparty.protobuf.ByteString;
 
@@ -203,6 +204,12 @@ public static RpcRequestHeaderProto makeRpcRequestHeader(RPC.RpcKind rpcKind,
       result.setCallerContext(contextBuilder);
     }
 
+    // Add the authorization header if one is set for the current thread
+    byte[] authzHeader = AuthorizationContext.getCurrentAuthorizationHeader();
+    if (authzHeader != null) {
+      result.setAuthorizationHeader(ByteString.copyFrom(authzHeader));
+    }
+
     // Add alignment context if it is not null
     if (alignmentContext != null) {
       alignmentContext.updateRequestState(result);
diff --git a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto
index d9becf722e982..19bdc96726b0e 100644
--- a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto
+++ b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto
@@ -95,6 +95,8 @@ message RpcRequestHeaderProto { // the header for the RpcRequest
   // The client should not interpret these bytes, but only forward bytes
   // received from RpcResponseHeaderProto.routerFederatedState.
   optional bytes routerFederatedState = 9;
+  // Authorization header for passing opaque credentials or tokens
+  optional bytes authorizationHeader = 10;
 }
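Because the new field is optional bytes, wire compatibility is preserved: older clients never set it and older servers ignore it, while hasAuthorizationHeader() guards the server-side read. A sketch of the field's behavior at the protobuf level (standard protobuf-java accessors; callId and clientId are the required fields of RpcRequestHeaderProto):

    RpcRequestHeaderProto header = RpcRequestHeaderProto.newBuilder()
        .setCallId(1)
        .setClientId(ByteString.EMPTY)
        .setAuthorizationHeader(ByteString.copyFromUtf8("opaque-token"))
        .build();
    // Presence is explicit for optional fields, so a missing header is
    // distinguishable from an empty one.
    assert header.hasAuthorizationHeader();
    byte[] onServer = header.getAuthorizationHeader().toByteArray();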
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java
new file mode 100644
index 0000000000000..417511691e2ae
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.security;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class TestAuthorizationContext {
+
+  @Test
+  public void testSetAndGetAuthorizationHeader() {
+    byte[] header = "my-auth-header".getBytes();
+    AuthorizationContext.setCurrentAuthorizationHeader(header);
+    Assertions.assertArrayEquals(header,
+        AuthorizationContext.getCurrentAuthorizationHeader());
+    AuthorizationContext.clear();
+  }
+
+  @Test
+  public void testClearAuthorizationHeader() {
+    byte[] header = "clear-me".getBytes();
+    AuthorizationContext.setCurrentAuthorizationHeader(header);
+    AuthorizationContext.clear();
+    Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
+  }
+
+  @Test
+  public void testThreadLocalIsolation() throws Exception {
+    byte[] mainHeader = "main-thread".getBytes();
+    AuthorizationContext.setCurrentAuthorizationHeader(mainHeader);
+    // Collect what the other thread sees and assert on the main thread:
+    // assertion failures thrown inside a spawned thread would otherwise
+    // be swallowed and the test would pass vacuously.
+    final byte[][] observed = new byte[3][];
+    Thread t = new Thread(() -> {
+      observed[0] = AuthorizationContext.getCurrentAuthorizationHeader();
+      AuthorizationContext.setCurrentAuthorizationHeader(
+          "other-thread".getBytes());
+      observed[1] = AuthorizationContext.getCurrentAuthorizationHeader();
+      AuthorizationContext.clear();
+      observed[2] = AuthorizationContext.getCurrentAuthorizationHeader();
+    });
+    t.start();
+    t.join();
+    Assertions.assertNull(observed[0]);
+    Assertions.assertArrayEquals("other-thread".getBytes(), observed[1]);
+    Assertions.assertNull(observed[2]);
+    // Main thread should still have its header
+    Assertions.assertArrayEquals(mainHeader,
+        AuthorizationContext.getCurrentAuthorizationHeader());
+    AuthorizationContext.clear();
+  }
+
+  @Test
+  public void testNullAndEmptyHeader() {
+    AuthorizationContext.setCurrentAuthorizationHeader(null);
+    Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
+    byte[] empty = new byte[0];
+    AuthorizationContext.setCurrentAuthorizationHeader(empty);
+    Assertions.assertArrayEquals(empty,
+        AuthorizationContext.getCurrentAuthorizationHeader());
+    AuthorizationContext.clear();
+  }
+}
diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java
new file mode 100644
index 0000000000000..025ac57c58be4
--- /dev/null
+++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hdfs.server.namenode;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.hdfs.HdfsConfiguration;
+import org.apache.hadoop.hdfs.MiniDFSCluster;
+import org.apache.hadoop.security.AuthorizationContext;
+import org.junit.jupiter.api.Test;
+
+import java.net.InetAddress;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_NAMENODE_AUDIT_LOGGERS_KEY;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+public class TestAuthorizationHeaderPropagation {
+
+  public static class HeaderCapturingAuditLogger implements AuditLogger {
+
+    public static final List<byte[]> CAPTURED_HEADERS = new ArrayList<>();
+
+    @Override
+    public void initialize(Configuration conf) {}
+
+    @Override
+    public void logAuditEvent(boolean succeeded, String userName,
+        InetAddress addr, String cmd, String src, String dst,
+        FileStatus stat) {
+      // Copy the thread-local header so later mutations cannot alias it.
+      byte[] header = AuthorizationContext.getCurrentAuthorizationHeader();
+      CAPTURED_HEADERS.add(header == null
+          ? null : Arrays.copyOf(header, header.length));
+    }
+  }
+
+  @Test
+  public void testAuthorizationHeaderPerRpc() throws Exception {
+    Configuration conf = new HdfsConfiguration();
+    conf.set(DFS_NAMENODE_AUDIT_LOGGERS_KEY,
+        HeaderCapturingAuditLogger.class.getName());
+    MiniDFSCluster cluster = new MiniDFSCluster.Builder(conf).build();
+    try {
+      cluster.waitClusterUp();
+      HeaderCapturingAuditLogger.CAPTURED_HEADERS.clear();
+      FileSystem fs = cluster.getFileSystem();
+      // First RPC with header1
+      byte[] header1 = "header-one".getBytes();
+      AuthorizationContext.setCurrentAuthorizationHeader(header1);
+      fs.mkdirs(new Path("/authz1"));
+      AuthorizationContext.clear();
+      // Second RPC with header2
+      byte[] header2 = "header-two".getBytes();
+      AuthorizationContext.setCurrentAuthorizationHeader(header2);
+      fs.mkdirs(new Path("/authz2"));
+      AuthorizationContext.clear();
+      // Third RPC with no header
+      fs.mkdirs(new Path("/authz3"));
+      // Each audit event should have captured the header (or null) that
+      // was set for the corresponding RPC.
+      assertArrayEquals(header1,
+          HeaderCapturingAuditLogger.CAPTURED_HEADERS.get(0));
+      assertArrayEquals(header2,
+          HeaderCapturingAuditLogger.CAPTURED_HEADERS.get(1));
+      assertNull(HeaderCapturingAuditLogger.CAPTURED_HEADERS.get(2));
+    } finally {
+      cluster.shutdown();
+    }
+  }
+}