diff --git a/src/emu/core/parser.go b/src/emu/core/parser.go index b183dc5..a67c808 100644 --- a/src/emu/core/parser.go +++ b/src/emu/core/parser.go @@ -12,6 +12,8 @@ package core import ( "encoding/binary" + "external/google/gopacket" + "external/google/gopacket/ip4defrag" "external/google/gopacket/layers" "fmt" "runtime" @@ -519,6 +521,8 @@ type Parser struct { eapol ParserCb ppp ParserCb Cdb *CCounterDb + + Defrag *ip4defrag.IPv4Defragmenter } func parserNotSupported(ps *ParserPacketState) int { @@ -578,6 +582,7 @@ func (o *Parser) Init(tctx *CThreadCtx) { o.mdns = parserNotSupported o.ppp = parserNotSupported o.Cdb = newParserStatsDb(&o.stats) + o.Defrag = ip4defrag.NewIPv4Defragmenter() } func (o *Parser) parsePacketL4(ps *ParserPacketState, @@ -830,10 +835,6 @@ func (o *Parser) ParsePacket(m *Mbuf) int { o.stats.errIPv4HeaderTooShort++ return PARSER_ERR } - if ipv4.IsFragment() { - o.stats.errIPv4Fragment++ - return PARSER_ERR - } hdr := ipv4.GetHeaderLen() if hdr < 20 { o.stats.errIPv4HeaderTooShort++ @@ -855,6 +856,56 @@ func (o *Parser) ParsePacket(m *Mbuf) int { o.stats.errIPv4cs++ return PARSER_ERR } + if ipv4.IsFragment() { + // Only handles fragmented IP packet containing UDP + if ipv4.GetNextProtocol() != uint8(layers.IPProtocolUDP) { + o.stats.errIPv4Fragment++ + return PARSER_ERR + } + + packet := gopacket.NewPacket(m.GetData(), layers.LayerTypeEthernet, gopacket.NoCopy) + ipv4Layer := packet.Layer(layers.LayerTypeIPv4) + if ipv4Layer == nil { + return PARSER_ERR + } + in := ipv4Layer.(*layers.IPv4) + out, err := o.Defrag.DefragIPv4(in) + if err != nil { + return PARSER_ERR + } + if out == nil { + // Packet is fragmented, wait for next fragment + return PARSER_OK + } + + // Decode defragmented packet + pb, ok := packet.(gopacket.PacketBuilder) + if !ok { + return PARSER_ERR + } + nextDecoder := out.NextLayerType() + err = nextDecoder.Decode(out.Payload, pb) + if err != nil { + return PARSER_ERR + } + // TODO: Call DiscardOlderThan + + // Encode defragmented packet to buffer + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + err = gopacket.SerializePacket(buf, opts, packet) + + // Allocated a new Mbuf for the defragmented packet + m = o.tctx.MPool.Alloc(offset + hdr + uint16(len(buf.Bytes()))) + defer m.FreeMbuf() + m.Append(p[0 : offset+hdr]) // Append ethernet and ipv4 header + m.Append(buf.Bytes()) // Append defragmented payload + + p = m.GetData() + ipv4 = layers.IPv4Header(p[offset : offset+hdr]) + ipv4.SetLength(hdr + uint16(len(buf.Bytes()))) // Set correct payload length + ps.M = m + } l4len := ipv4.GetLength() - ipv4.GetHeaderLen() ps.L4 = offset + hdr offset = ps.L4 diff --git a/src/emu/plugins/transport/udp.go b/src/emu/plugins/transport/udp.go index 2f68996..5060f8d 100644 --- a/src/emu/plugins/transport/udp.go +++ b/src/emu/plugins/transport/udp.go @@ -94,19 +94,56 @@ func (o *UdpSocket) Write(buf []byte) (res SocketErr, queued bool) { if o.isClosed { return SeCONNECTION_IS_CLOSED, false } - var pkt udpPkt - - if uint16(len(buf)) > o.GetL7MTU() { - o.ctx.udpStats.udp_drop_msg_bigger_mtu++ - return SeENOBUFS, false + if o.resolve() == false { + return SeUNRESOLVED, false } + var pkt udpPkt - if o.buildDpkt(&pkt, buf) < 0 { - if !o.resolved { - return SeUNRESOLVED, false + mtu := o.GetL7MTU() + if uint16(len(buf)) > mtu { + if o.ipv6 { + o.ctx.udpStats.udp_drop_msg_bigger_mtu++ + return SeENOBUFS, false + } + // Send large data in multiple IPv4 packets, fragmented + + // The payload size in a fragment need to be a multiple of 8 + maxPayloadSize := uint16(mtu/8) * 8 + payload := buf[:maxPayloadSize] + payloadSize := uint16(len(payload)) + pkt.datalen = uint16(len(buf)) + + // Send first fragment containing UDP header + o.buildDpktFirstFragment(&pkt, payload) + o.ctx.udpStats.udp_sndpack++ + o.ctx.udpStats.udp_sndbyte += uint64(payloadSize) + o.tctx.Veth.Send(pkt.m) + dataSent := payloadSize + + // Send trailing segments + for { + if dataSent+maxPayloadSize > pkt.datalen { + payload = buf[dataSent:] + } else { + payload = buf[dataSent : dataSent+maxPayloadSize] + } + payloadSize = uint16(len(payload)) + + o.buildDpktTrailingFragment(&pkt, payload, dataSent+UDP_HEADER_LEN) + o.ctx.udpStats.udp_sndpack++ + o.ctx.udpStats.udp_sndbyte += uint64(payloadSize) + o.tctx.Veth.Send(pkt.m) + + dataSent += payloadSize + if dataSent >= pkt.datalen { + break // All data is sent + } } - return SeENOBUFS, false + + return SeOK, true } + + o.buildDpkt(&pkt, buf) o.ctx.udpStats.udp_sndpack++ o.ctx.udpStats.udp_sndbyte += uint64(len(buf)) o.send(&pkt) @@ -159,10 +196,60 @@ func (o *UdpSocket) send(pkt *udpPkt) int { return 0 } -func (o *UdpSocket) buildDpkt(pkt *udpPkt, data []byte) int { - if o.resolve() == false { - return -1 +func (o *UdpSocket) buildDpktFirstFragment(pkt *udpPkt, data []byte) int { + tl := uint16(len(o.pktTemplate)) + dl := uint16(len(data)) + m := o.ns.AllocMbuf(tl + dl) + m.Append(o.pktTemplate) // template with IP and UDP header + m.Append(data) + pkt.m = m + + p := m.GetData() + l3 := o.l3Offset + + // Update fragment flags + f := (uint16(layers.IPv4MoreFragments) << 13) + binary.BigEndian.PutUint16(p[l3+6:l3+8], f) + + // Update UDP header length + binary.BigEndian.PutUint16(p[l3+24:l3+26], pkt.datalen+UDP_HEADER_LEN) + + // Update IPv4 header + ipv4 := layers.IPv4Header(p[l3 : l3+20]) + ipv4.SetLength(20 + uint16(len(data)) + UDP_HEADER_LEN) + ipv4.UpdateChecksum() + + return 0 +} + +func (o *UdpSocket) buildDpktTrailingFragment(pkt *udpPkt, data []byte, dataSent uint16) int { + tl := uint16(len(o.pktTemplate)) - UDP_HEADER_LEN + dl := uint16(len(data)) + m := o.ns.AllocMbuf(tl + dl) + m.Append(o.pktTemplate[:tl]) // template without UDP header + m.Append(data) + pkt.m = m + + p := m.GetData() + l3 := o.l3Offset + payload_size := uint16(len(data)) + + // Update fragment offset and flag if more fragments are required + f := dataSent / 8 + if dataSent+payload_size < pkt.datalen { + f += (uint16(layers.IPv4MoreFragments) << 13) } + binary.BigEndian.PutUint16(p[l3+6:l3+8], f) + + // Update IPv4 header + ipv4 := layers.IPv4Header(p[l3 : l3+20]) + ipv4.SetLength(20 + payload_size) + ipv4.UpdateChecksum() + + return 0 +} + +func (o *UdpSocket) buildDpkt(pkt *udpPkt, data []byte) int { dl := uint16(len(data)) m := o.ns.AllocMbuf(uint16(len(o.pktTemplate)) + dl) m.Append(o.pktTemplate) // template diff --git a/src/emu/plugins/transport_example/example.go b/src/emu/plugins/transport_example/example.go index ee09965..d2eb00d 100644 --- a/src/emu/plugins/transport_example/example.go +++ b/src/emu/plugins/transport_example/example.go @@ -25,6 +25,7 @@ const ( ) type TransEInit struct { + Network string `json:"network"` Addr string `json:"addr"` DataSize uint32 `json:"size"` Loops uint32 `json:"loops"` @@ -134,7 +135,7 @@ func NewTransportEClient(ctx *core.PluginCtx, initJson []byte) (*core.PluginBase nsplg := o.Ns.PluginCtx.GetOrCreate(TRANS_E_PLUG) o.tranNsPlug = nsplg.Ext.(*PluginTransportENs) - o.cfg = TransEInit{Addr: "48.0.0.1:80", DataSize: 10, Loops: 1} + o.cfg = TransEInit{Network: "tcp", Addr: "48.0.0.1:80", DataSize: 10, Loops: 1} ctx.Tctx.UnmarshalValidate(initJson, &o.cfg) o.b = make([]byte, o.cfg.DataSize) for i := 0; i < int(o.cfg.DataSize-1); i++ { @@ -222,11 +223,14 @@ func (o *PluginTransportEClient) OnEvent(msg string, a, b interface{}) { if resolvedIPv4 { // now we can dial o.ctx = transport.GetTransportCtx(o.Client) - s, err := o.ctx.Dial("tcp", o.cfg.Addr, o, nil, nil, 0) + s, err := o.ctx.Dial(o.cfg.Network, o.cfg.Addr, o, nil, nil, 0) if err != nil { return } o.s = s + if o.cfg.Network == "udp" { + o.startLoop() + } } } }