Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const ERR_GET_SESSION_FAILED = -102
const ERR_ENCRYPT_FAILED = -103
const ERR_DECRYPT_FAILED = -104
const ERR_BAD_CONFIG = -105
const ERR_PANIC = -106

const EstimatedEncryptionOverhead = 48
const EstimatedEnvelopeOverhead = 185
Expand Down
6 changes: 4 additions & 2 deletions internal/asherah/asherah.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (f logFunc) Debugf(format string, v ...interface{}) {
}

func Setup(options *Options) error {
if atomic.LoadInt32(&globalInitialized) == 1 {
if !atomic.CompareAndSwapInt32(&globalInitialized, 0, 1) {
log.ErrorLog("Failed to initialize asherah: already initialized")
return ErrAsherahAlreadyInitialized
}
Expand Down Expand Up @@ -73,17 +73,18 @@ func Setup(options *Options) error {

if globalSessionFactory == nil {
log.ErrorLog("Failed to create session factory")
atomic.StoreInt32(&globalInitialized, 0)
return ErrAsherahFailedInitialization
}

atomic.StoreInt32(&globalInitialized, 1)
return nil
}

func Shutdown() {
if atomic.CompareAndSwapInt32(&globalInitialized, 1, 0) {
globalSessionFactory.Close()
globalSessionFactory = nil
closeConnection()
}
}

Expand All @@ -108,6 +109,7 @@ func Encrypt(partitionId string, data []byte) (*appencryption.DataRowRecord, err
func Decrypt(partitionId string, drr *appencryption.DataRowRecord) ([]byte, error) {
// Atomic read to prevent race with Shutdown()
if atomic.LoadInt32(&globalInitialized) == 0 {
log.ErrorLog("Failed to decrypt data: asherah is not initialized")
return nil, ErrAsherahNotInitialized
}

Expand Down
7 changes: 7 additions & 0 deletions internal/asherah/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ var (
dbconnection *sql.DB
)

func closeConnection() {
if dbconnection != nil {
dbconnection.Close()
dbconnection = nil
}
}

func newConnection(dbdriver string, connStr string) (*sql.DB, error) {
var err error
if dbconnection == nil {
Expand Down
67 changes: 41 additions & 26 deletions libasherah.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"os"
"sync/atomic"

"github.com/godaddy/cobhan-go"

Expand All @@ -19,7 +20,7 @@ import (
)

var EstimatedIntermediateKeyOverhead = 0
var DisableZeroCopy = false
var disableZeroCopy atomic.Bool

func main() {
}
Expand All @@ -41,32 +42,35 @@ type Env map[string]string
https://github.com/golang/go/issues/27693
*/
//export SetEnv
func SetEnv(envJson unsafe.Pointer) int32 {
func SetEnv(envJson unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("SetEnv: Panic: %v", r)
panic(r)
result = ERR_PANIC
}
}()

cobhan.AllowTempFileBuffers(false)
env := Env{}

result := cobhan.BufferToJsonStruct(envJson, &env)
result = cobhan.BufferToJsonStruct(envJson, &env)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to deserialize environment JSON string %v", cobhan.CobhanErrorToString(result))
return result
}

for k, v := range env {
os.Setenv(k, v)
if err := os.Setenv(k, v); err != nil {
log.ErrorLogf("Failed to set environment variable %v: %v", k, err)
return ERR_BAD_CONFIG
}
}

return cobhan.ERR_NONE
}

//export SetupJson
func SetupJson(configJson unsafe.Pointer) int32 {
func SetupJson(configJson unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("SetupJson: Panic: %v", r)
Expand All @@ -76,12 +80,12 @@ func SetupJson(configJson unsafe.Pointer) int32 {

cobhan.AllowTempFileBuffers(false)
options := &asherah.Options{}
result := cobhan.BufferToJsonStruct(configJson, options)
result = cobhan.BufferToJsonStruct(configJson, options)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to deserialize configuration string %v", cobhan.CobhanErrorToString(result))
configString, stringResult := cobhan.BufferToString(configJson)
if stringResult != cobhan.ERR_NONE {
log.ErrorLogf("Could not convert configJson to string: %v", cobhan.CobhanErrorToString(result))
log.ErrorLogf("Could not convert configJson to string: %v", cobhan.CobhanErrorToString(stringResult))
return result
}
log.ErrorLogf("Could not deserialize: %v", configString)
Expand All @@ -93,7 +97,7 @@ func SetupJson(configJson unsafe.Pointer) int32 {
log.DebugLog("Successfully deserialized config JSON")

EstimatedIntermediateKeyOverhead = len(options.ProductID) + len(options.ServiceName)
DisableZeroCopy = options.DisableZeroCopy
disableZeroCopy.Store(options.DisableZeroCopy)

err := asherah.Setup(options)
if err == asherah.ErrAsherahAlreadyInitialized {
Expand Down Expand Up @@ -125,27 +129,30 @@ func EstimateBufferInt(dataLen int, partitionLen int) int {

//export Decrypt
func Decrypt(partitionIdPtr unsafe.Pointer, encryptedDataPtr unsafe.Pointer, encryptedKeyPtr unsafe.Pointer,
created int64, parentKeyIdPtr unsafe.Pointer, parentKeyCreated int64, outputDecryptedDataPtr unsafe.Pointer) int32 {
created int64, parentKeyIdPtr unsafe.Pointer, parentKeyCreated int64, outputDecryptedDataPtr unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("Decrypt: Panic: %v", r)
panic(r)
result = ERR_PANIC
}
}()

encryptedData, result := cobhan.BufferToBytes(encryptedDataPtr)
var encryptedData []byte
encryptedData, result = cobhan.BufferToBytes(encryptedDataPtr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Decrypt failed: Failed to convert encryptedDataPtr cobhan buffer to bytes %v", cobhan.CobhanErrorToString(result))
return result
}

encryptedKey, result := cobhan.BufferToBytes(encryptedKeyPtr)
var encryptedKey []byte
encryptedKey, result = cobhan.BufferToBytes(encryptedKeyPtr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Decrypt failed: Failed to convert encryptedKeyPtr cobhan buffer to bytes %v", cobhan.CobhanErrorToString(result))
return result
}

parentKeyId, result := cobhan.BufferToString(parentKeyIdPtr)
var parentKeyId string
parentKeyId, result = cobhan.BufferToString(parentKeyIdPtr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Decrypt failed: Failed to convert parentKeyIdPtr cobhan buffer to string %v", cobhan.CobhanErrorToString(result))
return result
Expand All @@ -163,7 +170,9 @@ func Decrypt(partitionIdPtr unsafe.Pointer, encryptedDataPtr unsafe.Pointer, enc
},
}

data, result, err := decryptData(partitionIdPtr, &drr)
var data []byte
var err error
data, result, err = decryptData(partitionIdPtr, &drr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to decrypt data %v", cobhan.CobhanErrorToString(result))
log.ErrorLogf("Decrypt: decryptData returned %v", err)
Expand All @@ -176,15 +185,17 @@ func Decrypt(partitionIdPtr unsafe.Pointer, encryptedDataPtr unsafe.Pointer, enc
//export Encrypt
func Encrypt(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer, outputEncryptedDataPtr unsafe.Pointer,
outputEncryptedKeyPtr unsafe.Pointer, outputCreatedPtr unsafe.Pointer, outputParentKeyIdPtr unsafe.Pointer,
outputParentKeyCreatedPtr unsafe.Pointer) int32 {
outputParentKeyCreatedPtr unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("Encrypt: Panic: %v", r)
panic(r)
result = ERR_PANIC
}
}()

drr, result, err := encryptData(partitionIdPtr, dataPtr)
var drr *appencryption.DataRowRecord
var err error
drr, result, err = encryptData(partitionIdPtr, dataPtr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to encrypt data %v", cobhan.CobhanErrorToString(result))
log.ErrorLogf("Encrypt failed: encryptData returned %v", err)
Expand Down Expand Up @@ -226,15 +237,17 @@ func Encrypt(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer, outputEncryp
}

//export EncryptToJson
func EncryptToJson(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer, jsonPtr unsafe.Pointer) int32 {
func EncryptToJson(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer, jsonPtr unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("EncryptToJson: Panic: %v", r)
panic(r)
result = ERR_PANIC
}
}()

drr, result, err := encryptData(partitionIdPtr, dataPtr)
var drr *appencryption.DataRowRecord
var err error
drr, result, err = encryptData(partitionIdPtr, dataPtr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to encrypt data %v", cobhan.CobhanErrorToString(result))
log.ErrorLogf("EncryptToJson failed: encryptData returned %v", err)
Expand All @@ -258,22 +271,24 @@ func EncryptToJson(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer, jsonPt
}

//export DecryptFromJson
func DecryptFromJson(partitionIdPtr unsafe.Pointer, jsonPtr unsafe.Pointer, dataPtr unsafe.Pointer) int32 {
func DecryptFromJson(partitionIdPtr unsafe.Pointer, jsonPtr unsafe.Pointer, dataPtr unsafe.Pointer) (result int32) {
defer func() {
if r := recover(); r != nil {
log.ErrorLogf("DecryptFromJson: Panic: %v", r)
panic(r)
result = ERR_PANIC
}
}()

var drr appencryption.DataRowRecord
result := cobhan.BufferToJsonStruct(jsonPtr, &drr)
result = cobhan.BufferToJsonStruct(jsonPtr, &drr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("DecryptFromJson failed: Failed to convert cobhan buffer to JSON structs %v", cobhan.CobhanErrorToString(result))
return result
}

data, result, err := decryptData(partitionIdPtr, &drr)
var data []byte
var err error
data, result, err = decryptData(partitionIdPtr, &drr)
if result != cobhan.ERR_NONE {
log.ErrorLogf("Failed to decrypt data %v", cobhan.CobhanErrorToString(result))
log.ErrorLogf("DecryptFromJson failed: decryptData returned %v", err)
Expand Down Expand Up @@ -306,7 +321,7 @@ func encryptData(partitionIdPtr unsafe.Pointer, dataPtr unsafe.Pointer) (*appenc
return nil, result, errors.New(errorMessage)
}

if DisableZeroCopy {
if disableZeroCopy.Load() {
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
data = dataCopy
Expand Down
2 changes: 1 addition & 1 deletion libasherah_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func TestEncryptDecryptRoundTripWithDefensiveCopy(t *testing.T) {
}
defer Shutdown()

if !DisableZeroCopy {
if !disableZeroCopy.Load() {
t.Error("DisableZeroCopy was not set by SetupJson")
}

Expand Down