| 
 | 1 | +package sphinx  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"crypto/hmac"  | 
 | 5 | +	"crypto/sha256"  | 
 | 6 | +	"io"  | 
 | 7 | +)  | 
 | 8 | + | 
 | 9 | +type payloadSource byte  | 
 | 10 | + | 
 | 11 | +const (  | 
 | 12 | +	// payloadIntermediateNode is a marker to signal that this attributable  | 
 | 13 | +	// error payload is originating from a node between the payer and the error  | 
 | 14 | +	// source.  | 
 | 15 | +	payloadIntermediateNode payloadSource = 0  | 
 | 16 | + | 
 | 17 | +	// payloadErrorNode is a marker to signal that this attributable error  | 
 | 18 | +	// payload is originating from the error source.  | 
 | 19 | +	payloadErrorNode payloadSource = 1  | 
 | 20 | +)  | 
 | 21 | + | 
 | 22 | +type AttributableErrorStructure struct {  | 
 | 23 | +	HopCount        int  | 
 | 24 | +	FixedPayloadLen int  | 
 | 25 | +}  | 
 | 26 | + | 
 | 27 | +type attributableErrorBase struct {  | 
 | 28 | +	maxHops             int  | 
 | 29 | +	totalHmacs          int  | 
 | 30 | +	allHmacsLen         int  | 
 | 31 | +	hmacsAndPayloadsLen int  | 
 | 32 | +	allPayloadsLen      int  | 
 | 33 | +	payloadLen          int  | 
 | 34 | +	payloadDataLen      int  | 
 | 35 | +}  | 
 | 36 | + | 
 | 37 | +func newAttributableErrorBase(  | 
 | 38 | +	structure *AttributableErrorStructure) attributableErrorBase {  | 
 | 39 | + | 
 | 40 | +	var (  | 
 | 41 | +		payloadDataLen = structure.FixedPayloadLen  | 
 | 42 | + | 
 | 43 | +		// payloadLen is the size of the per-node payload. It consists of a  | 
 | 44 | +		// 1-byte payload type followed by the payload data.  | 
 | 45 | +		payloadLen = 1 + payloadDataLen  | 
 | 46 | + | 
 | 47 | +		totalHmacs = (structure.HopCount *  | 
 | 48 | +			(structure.HopCount + 1)) / 2  | 
 | 49 | + | 
 | 50 | +		allHmacsLen         = totalHmacs * sha256.Size  | 
 | 51 | +		allPayloadsLen      = payloadLen * structure.HopCount  | 
 | 52 | +		hmacsAndPayloadsLen = allHmacsLen + allPayloadsLen  | 
 | 53 | +	)  | 
 | 54 | + | 
 | 55 | +	return attributableErrorBase{  | 
 | 56 | +		totalHmacs:          totalHmacs,  | 
 | 57 | +		allHmacsLen:         allHmacsLen,  | 
 | 58 | +		hmacsAndPayloadsLen: hmacsAndPayloadsLen,  | 
 | 59 | +		allPayloadsLen:      allPayloadsLen,  | 
 | 60 | +		maxHops:             structure.HopCount,  | 
 | 61 | +		payloadLen:          payloadLen,  | 
 | 62 | +		payloadDataLen:      payloadDataLen,  | 
 | 63 | +	}  | 
 | 64 | +}  | 
 | 65 | + | 
 | 66 | +// getMsgComponents splits a complete failure message into its components  | 
 | 67 | +// without re-allocating memory.  | 
 | 68 | +func (o *attributableErrorBase) getMsgComponents(data []byte) ([]byte, []byte,  | 
 | 69 | +	[]byte) {  | 
 | 70 | + | 
 | 71 | +	payloads := data[len(data)-o.hmacsAndPayloadsLen : len(data)-o.allHmacsLen]  | 
 | 72 | +	hmacs := data[len(data)-o.allHmacsLen:]  | 
 | 73 | +	message := data[:len(data)-o.hmacsAndPayloadsLen]  | 
 | 74 | + | 
 | 75 | +	return message, payloads, hmacs  | 
 | 76 | +}  | 
 | 77 | + | 
 | 78 | +// calculateHmac calculates an hmac given a shared secret and a presumed  | 
 | 79 | +// position in the path. Position is expressed as the distance to the error  | 
 | 80 | +// source. The error source itself is at position 0.  | 
 | 81 | +func (o *attributableErrorBase) calculateHmac(sharedSecret Hash256,  | 
 | 82 | +	position int, message, payloads, hmacs []byte) []byte {  | 
 | 83 | + | 
 | 84 | +	umKey := generateKey("um", &sharedSecret)  | 
 | 85 | +	hash := hmac.New(sha256.New, umKey[:])  | 
 | 86 | + | 
 | 87 | +	// Include message.  | 
 | 88 | +	_, _ = hash.Write(message)  | 
 | 89 | + | 
 | 90 | +	// Include payloads including our own.  | 
 | 91 | +	_, _ = hash.Write(payloads[:(o.maxHops-position)*o.payloadLen])  | 
 | 92 | + | 
 | 93 | +	// Include downstream hmacs.  | 
 | 94 | +	writeDownstreamHmacs(position, o.maxHops, hmacs, hash)  | 
 | 95 | + | 
 | 96 | +	return hash.Sum(nil)  | 
 | 97 | +}  | 
 | 98 | + | 
 | 99 | +// writeDownstreamHmacs writes the hmacs of downstream nodes that are relevant  | 
 | 100 | +// for the given position to a writer instance.  | 
 | 101 | +func writeDownstreamHmacs(position, maxHops int, hmacs []byte, w io.Writer) {  | 
 | 102 | +	// Track the index of the next hmac to write in a variable. The first  | 
 | 103 | +	// maxHops slots are reserved for the hmacs of the current hop and can  | 
 | 104 | +	// therefore be skipped. The first hmac to write is part of the block of  | 
 | 105 | +	// hmacs that was written by the first downstream node. Which hmac exactly  | 
 | 106 | +	// is determined by the assumed position of the current node.  | 
 | 107 | +	var hmacIdx = maxHops + position  | 
 | 108 | + | 
 | 109 | +	// Iterate over all downstream nodes.  | 
 | 110 | +	for j := 0; j < maxHops-position-1; j++ {  | 
 | 111 | +		_, _ = w.Write(  | 
 | 112 | +			hmacs[hmacIdx*sha256.Size : (hmacIdx+1)*sha256.Size],  | 
 | 113 | +		)  | 
 | 114 | + | 
 | 115 | +		// Calculate the total number of hmacs in the block of the current  | 
 | 116 | +		// downstream node.  | 
 | 117 | +		blockSize := maxHops - j - 1  | 
 | 118 | + | 
 | 119 | +		// Skip to the next block. The new hmac index will point to the hmac  | 
 | 120 | +		// that corresponds to the next downstream node which is one step closer  | 
 | 121 | +		// to the assumed error source.  | 
 | 122 | +		hmacIdx += blockSize  | 
 | 123 | +	}  | 
 | 124 | +}  | 
0 commit comments