Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions modules/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"os"
"os/exec"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
Expand All @@ -33,9 +34,21 @@ import (
gossh "golang.org/x/crypto/ssh"
)

type contextKey string

const giteaKeyID = contextKey("gitea-key-id")
// The ssh auth overall works like this:
// NewServerConn:
// serverHandshake+serverAuthenticate:
// PublicKeyCallback:
// PublicKeyHandler (our code):
// clear(ctx.Permissions) and set ctx.Permissions.giteaKeyID = keyID
// pubKey.Verify
// return ctx.Permissions // only reaches here, the pub key is really authenticated
// set conn.Permissions from serverAuthenticate
// sessionHandler(conn)
//
// Then sessionHandler should only use the "verified keyID" from the conn.
// Otherwise, if a users provides 2 keys A and B, if A succeeds to authenticate, sessionHandler will see B's keyID

const giteaPermissionExtensionKeyID = "gitea-perm-ext-key-id"

func getExitStatusFromError(err error) int {
if err == nil {
Expand All @@ -61,8 +74,26 @@ func getExitStatusFromError(err error) int {
return waitStatus.ExitStatus()
}

type sessionPartial struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
}

func ptr[T any](intf any) *T {
// https://pkg.go.dev/unsafe#Pointer
// (1) Conversion of a *T1 to Pointer to *T2.
// Provided that T2 is no larger than T1 and that the two share an equivalent memory layout,
// this conversion allows reinterpreting data of one type as data of another type.
v := reflect.ValueOf(intf)
p := v.UnsafePointer()
return (*T)(p)
}

func sessionHandler(session ssh.Session) {
keyID := fmt.Sprintf("%d", session.Context().Value(giteaKeyID).(int64))
// it can't use session.Permissions() because it only use the ctx one, so we must use the original ssh conn
sshConn := ptr[sessionPartial](session)
keyID := sshConn.conn.Permissions.Extensions[giteaPermissionExtensionKeyID]

command := session.RawCommand()

Expand Down Expand Up @@ -164,6 +195,12 @@ func sessionHandler(session ssh.Session) {
}

func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
setPermExt := func(keyID int64) {
ctx.Permissions().Permissions.Extensions = map[string]string{
giteaPermissionExtensionKeyID: fmt.Sprint(keyID),
}
}

if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Handle Public Key: Fingerprint: %s from %s", gossh.FingerprintSHA256(key), ctx.RemoteAddr())
}
Expand Down Expand Up @@ -238,8 +275,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Successfully authenticated: %s Certificate Fingerprint: %s Principal: %s", ctx.RemoteAddr(), gossh.FingerprintSHA256(key), principal)
}
ctx.SetValue(giteaKeyID, pkey.ID)

setPermExt(pkey.ID)
return true
}

Expand All @@ -266,8 +302,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
if log.IsDebug() { // <- FingerprintSHA256 is kinda expensive so only calculate it if necessary
log.Debug("Successfully authenticated: %s Public Key Fingerprint: %s", ctx.RemoteAddr(), gossh.FingerprintSHA256(key))
}
ctx.SetValue(giteaKeyID, pkey.ID)

setPermExt(pkey.ID)
return true
}

Expand Down
32 changes: 32 additions & 0 deletions modules/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package ssh

import (
"testing"

"github.com/stretchr/testify/assert"
)

type S1 struct {
a, b, c int
}

func (s S1) S1Func() {}

type S1Intf interface {
S1Func()
}

type S2 struct {
a, b int
}

func TestPtr(t *testing.T) {
s1 := &S1{1, 2, 3}
var intf S1Intf = s1
s2 := ptr[S2](intf)
assert.Equal(t, 1, s2.a)
assert.Equal(t, 2, s2.b)
}
Loading